Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 133 additions & 73 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import (
AutoConfig,
Expand Down Expand Up @@ -315,98 +316,157 @@ def get_processor(
return None


def _load_inlined_mtp_tensors(model_path: Path, mtp_prefixes: set[str]) -> dict[str, torch.Tensor]:
    """Stream tensors whose keys start with any ``{prefix}.`` from on-disk shards.

    Walks ``model.safetensors.index.json`` when present, else falls back to a
    single ``model.safetensors`` file. Only the matching tensors are read, so
    unrelated shards are never opened.

    Args:
        model_path: Directory containing the safetensors checkpoint.
        mtp_prefixes: Module prefixes (without trailing dot), e.g.
            ``{"model.layers.78"}``.

    Returns:
        Mapping of checkpoint key to CPU tensor for every matching key. Empty
        if no matching keys (or no checkpoint files) are found.
    """
    # The trailing dot keeps "model.layers.7" from also matching "model.layers.78.*".
    # ``str.startswith`` accepts a tuple, so build it once instead of per key.
    dotted_prefixes = tuple(f"{prefix}." for prefix in mtp_prefixes)

    def _matches(key: str) -> bool:
        return key.startswith(dotted_prefixes)

    tensors: dict[str, torch.Tensor] = {}
    index_file = model_path / "model.safetensors.index.json"
    if index_file.exists():
        # Use a context manager so the index file handle is closed promptly.
        with open(index_file, encoding="utf-8") as index_fp:
            weight_map = json.load(index_fp)["weight_map"]

        # Group matching keys by shard so each shard is opened exactly once.
        per_shard: dict[str, list[str]] = {}
        for key, shard in weight_map.items():
            if _matches(key):
                per_shard.setdefault(shard, []).append(key)

        for shard, keys in per_shard.items():
            with safe_open(str(model_path / shard), framework="pt", device="cpu") as f:
                for k in keys:
                    tensors[k] = f.get_tensor(k)
    else:
        single = model_path / "model.safetensors"
        if single.exists():
            with safe_open(str(single), framework="pt", device="cpu") as f:
                # safe_open is not a dict; ``.keys()`` is the public listing API.
                for k in f.keys():  # noqa: SIM118
                    if _matches(k):
                        tensors[k] = f.get_tensor(k)
    return tensors


def _extract_mtp_layer_prefixes(keys) -> set[str]:
    """Collect module prefixes to exclude from quantization for the given MTP keys.

    For each key this records the top-level module name (e.g. ``"mtp"`` from
    ``"mtp.fc.weight"``) so non-layer weights such as ``mtp.fc`` / ``mtp.norm``
    are covered, plus the first ``<...>.layers.<N>`` prefix if one is present
    (e.g. ``"mtp.layers.0"``).
    """
    prefixes: set[str] = set()
    for key in keys:
        parts = key.split(".")
        # NOTE(review): for a key like "model.x.mtp.w" this adds "model" as a
        # prefix; confirm that is the intended exclusion scope.
        prefixes.add(parts[0])
        for i, part in enumerate(parts):
            if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                prefixes.add(".".join(parts[: i + 2]))
                break
    return prefixes


def load_mtp_weights(
    model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
    """Load MTP weights from the model checkpoint.

    Detects MTP layers under two on-disk conventions:

    1. Separate-file: weights live in a non-standard safetensors file (e.g.,
       ``mtp.safetensors``) with keys prefixed by ``mtp``. HuggingFace's
       ``from_pretrained()`` may not load these files even though they are
       referenced in ``model.safetensors.index.json``.
    2. Inlined: weights live in the main shards under
       ``model.layers[num_hidden : num_hidden + num_nextn_predict_layers]``
       (DeepSeek-V3, GLM-5.1 ``GlmMoeDsa``, GLM-4.7).

    Whether the HF model class actually instantiates the MTP module varies:
    DeepSeek-V3's modeling code adds the extra decoder layers, while
    ``GlmMoeDsaModel`` in transformers >=5.7 builds only ``num_hidden`` layers
    and leaves the inlined-MTP keys orphaned. To keep both paths correct, we
    always read the tensors off disk here and split them: any key that maps
    to a parameter in ``model.state_dict()`` is loaded into the model;
    the remainder is returned as ``not_in_state_dict`` for the exporter to
    merge in via ``extra_state_dict``.

    Args:
        model: The loaded model that may be missing weights.
        model_path: Path to the model directory.

    Returns:
        A ``(prefixes, not_in_state_dict)`` tuple:

        * Sorted list of layer prefixes that should be excluded from
          quantization (e.g., ``"mtp.layers.0"`` or ``"model.layers.78"``).
          Empty if no MTP layers were detected.
        * Dictionary of MTP weights that have no slot in the model's state
          dict; ``export_hf_checkpoint`` merges these via ``extra_state_dict``.
    """
    model_dir = Path(model_path)
    model_state = model.state_dict()
    mtp_layer_prefixes: set[str] = set()
    not_in_state_dict: dict[str, torch.Tensor] = {}

    # Inlined-MTP convention: keys ``model.layers.{i}.*`` for
    # ``i in [num_hidden, num_hidden + num_nextn)``.
    cfg = getattr(model, "config", None)
    # Some configs define the attribute but leave it as ``None``; treat as 0.
    num_nextn = int(getattr(cfg, "num_nextn_predict_layers", 0) or 0)
    if num_nextn > 0:
        # Only read ``num_hidden_layers`` when it is actually needed, so configs
        # without MTP layers are not required to expose it.
        num_hidden = cfg.num_hidden_layers
        inlined_prefixes = {f"model.layers.{i}" for i in range(num_hidden, num_hidden + num_nextn)}
        inlined_tensors = _load_inlined_mtp_tensors(model_dir, inlined_prefixes)
        if inlined_tensors:
            mtp_layer_prefixes |= inlined_prefixes
            in_state_dict = {k: v for k, v in inlined_tensors.items() if k in model_state}
            not_in_state_dict |= {k: v for k, v in inlined_tensors.items() if k not in model_state}
            if in_state_dict:
                model.load_state_dict(in_state_dict, strict=False)
            print(
                f"✓ Detected {len(inlined_tensors)} inlined MTP tensors under "
                f"{sorted(inlined_prefixes)} "
                f"(loaded into model: {len(in_state_dict)}, orphaned: {len(not_in_state_dict)})"
            )

    # Separate-file MTP detection via the safetensors index.
    index_file = model_dir / "model.safetensors.index.json"
    if index_file.exists():
        # Context manager ensures the index file handle is closed.
        with open(index_file, encoding="utf-8") as index_fp:
            weight_map = json.load(index_fp)["weight_map"]

        # Group every entry whose key or shard filename contains "mtp" by shard.
        mtp_weight_map: dict[str, list[str]] = {}
        for key, shard in weight_map.items():
            if "mtp" in key or "mtp" in shard:
                mtp_weight_map.setdefault(shard, []).append(key)

        if mtp_weight_map:
            mtp_layer_prefixes |= _extract_mtp_layer_prefixes(
                key for keys in mtp_weight_map.values() for key in keys
            )

            # Load any weights missing from model.state_dict from the non-standard files.
            total_loaded = 0
            for filename, mtp_keys in mtp_weight_map.items():
                filepath = model_dir / filename
                if not filepath.exists():
                    continue

                print(f"Loading {len(mtp_keys)} mtp weights from {filename}...")
                wanted = set(mtp_keys)  # O(1) membership while filtering the shard
                weights = {
                    k: v for k, v in load_file(str(filepath), device="cpu").items() if k in wanted
                }
                in_state_dict = {k: v for k, v in weights.items() if k in model_state}
                not_in_state_dict |= {k: v for k, v in weights.items() if k not in model_state}

                if in_state_dict:
                    model.load_state_dict(in_state_dict, strict=False)
                    total_loaded += len(in_state_dict)

            if total_loaded > 0:
                print(
                    f"✓ Successfully loaded {total_loaded} MTP weights, "
                    f"{len(not_in_state_dict)} MTP weights not in model.state_dict"
                )

    if mtp_layer_prefixes:
        print(f"✓ Detected MTP layers to exclude from quantization: {sorted(mtp_layer_prefixes)}")

    # Sorted so callers get a deterministic order rather than set iteration order.
    return sorted(mtp_layer_prefixes), not_in_state_dict

Expand Down
Loading