Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 110 additions & 82 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
import shutil
import sys
import warnings
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any

import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from safetensors.torch import load_file
from safetensors import safe_open
from transformers import (
AutoConfig,
AutoModel,
Expand Down Expand Up @@ -315,100 +316,127 @@ def get_processor(
return None


def load_mtp_weights(
model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
"""Load MTP weights from the model checkpoint.
def get_inlined_mtp_prefixes(config: Any) -> list[str]:
    """Return the state-dict prefixes of the MTP layers inlined after the decoder stack.

    Args:
        config: Object exposing ``num_hidden_layers`` and, optionally,
            ``num_nextn_predict_layers``.

    Returns:
        ``["model.layers.{i}", ...]`` for ``i`` running from ``num_hidden_layers``
        up to (but excluding) ``num_hidden_layers + num_nextn_predict_layers``;
        an empty list when the config declares no next-n prediction layers.
    """
    # A config may carry ``num_nextn_predict_layers=None`` instead of omitting
    # the attribute; ``or 0`` folds both spellings into zero.
    mtp_layer_count = int(getattr(config, "num_nextn_predict_layers", 0) or 0)
    if mtp_layer_count == 0:
        return []

    first_mtp_index = config.num_hidden_layers
    layer_ids = range(first_mtp_index, first_mtp_index + mtp_layer_count)
    return [f"model.layers.{layer_id}" for layer_id in layer_ids]

Some models store additional layers in separate safetensors files with non-standard
names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
files even though they're referenced in model.safetensors.index.json.

This function detects such cases and explicitly loads the missing weights.

Args:
model: The loaded model that may be missing weights
model_path: Path to the model directory
def _keys_to_prefixes(keys: Iterable[str]) -> set[str]:
"""Invert separate-file MTP keys into the prefixes the exporter needs for exclude_modules.
``"mtp.fc.weight"`` → ``{"mtp"}``; ``"mtp.layers.0.q_proj.weight"`` →
``{"mtp", "mtp.layers.0"}``. Caller must filter out inlined keys; otherwise
``"model.layers.78.eh_proj.weight"`` would emit ``"model"`` as a prefix.
"""
prefixes: set[str] = set()
for key in keys:
parts = key.split(".")
if parts:
prefixes.add(parts[0])
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
prefixes.add(".".join(parts[: i + 2]))
break
return prefixes


def _load_tensors_matching(
model_dir: Path, predicate: Callable[[str], bool]
) -> dict[str, torch.Tensor]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Minor: _keys_to_prefixes is only correct for separate-file keys — if fed an inlined key like "model.layers.78.eh_proj.weight" it would emit "model" as a top-level prefix, which is not what the exporter wants. The current caller filters via separate_keys = [k for k in tensors if not k.startswith(inlined_tuple)] so this is safe today; worth a one-line note in the docstring ("caller must filter inlined keys first") to prevent future misuse.

"""Stream tensors satisfying ``predicate(key)`` from every safetensors
source in ``model_dir`` (indexed shards + standalone files, each opened
at most once).
"""
tensors: dict[str, torch.Tensor] = {}
seen_shards: set[str] = set()

Returns:
List of layer prefixes that were loaded from non-standard safetensors files.
These layers should typically be excluded from quantization.
Empty list if no additional weights were loaded.
Dictionary of MTP weights that were not loaded into the model state dict.
index_file = model_dir / "model.safetensors.index.json"
if index_file.exists():
with open(index_file) as f:
weight_map = json.load(f)["weight_map"]
per_shard: dict[str, list[str]] = {}
for key, shard_name in weight_map.items():
if predicate(key):
per_shard.setdefault(shard_name, []).append(key)
for shard_name, keys in per_shard.items():
seen_shards.add(shard_name)
with safe_open(str(model_dir / shard_name), framework="pt", device="cpu") as f:
for k in keys:
tensors[k] = f.get_tensor(k)

for shard in sorted(model_dir.glob("*.safetensors")):
if shard.name in seen_shards:
continue
with safe_open(str(shard), framework="pt", device="cpu") as f:
for k in f.keys(): # noqa: SIM118 - safe_open is not iterable
if predicate(k):
tensors[k] = f.get_tensor(k)
return tensors


def _apply_to_model_state_dict(
model: torch.nn.Module, tensors: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""Load tensors with a slot in ``model.state_dict()`` in-place; return the
rest as orphans for ``extra_state_dict``.
"""
model_path = Path(model_path)
index_file = model_path / "model.safetensors.index.json"
model_state = model.state_dict()
in_state_dict = {k: v for k, v in tensors.items() if k in model_state}
out_state_dict = {k: v for k, v in tensors.items() if k not in model_state}
if in_state_dict:
model.load_state_dict(in_state_dict, strict=False)
return out_state_dict

if not index_file.exists():
return [], {}

# Load the index to find all referenced safetensors files
index = json.load(open(index_file))
weight_map = index["weight_map"]
# Find all files in weight_map whose key or value contains "mtp"
mtp_weight_map = {}
for k, v in weight_map.items():
if "mtp" in k or "mtp" in v:
mtp_weight_map.setdefault(v, []).append(k)
def load_mtp_weights(
    model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
    """Detect and load MTP weights. Support matrix:

        Convention      Architectures              On-disk shape
        -------------   ------------------------   -------------------------------
        inlined         GLM-5.1 (``GlmMoeDsa``),   ``model.layers.{N}.*``
                        DeepSeek-V3
        separate-file   GLM-4.7                    standalone ``mtp.safetensors``
        separate-file   Qwen3-Next                 indexed ``mtp.*`` tail shard

    Inlined ``N`` in ``[num_hidden, num_hidden + num_nextn_predict_layers)``;
    may be orphaned at ``from_pretrained`` time if the HF class only builds
    ``num_hidden`` decoders.

    Args:
        model: Loaded model. Its ``config`` attribute, when present, is used to
            derive the inlined-MTP layer prefixes.
        model_path: Path to the checkpoint directory.

    Returns:
        ``(prefixes, not_in_state_dict)``: ``prefixes`` (sorted) populates
        ``quantization_config.exclude_modules``; ``not_in_state_dict`` is fed to
        ``export_hf_checkpoint(extra_state_dict=...)``. Both are empty when no
        MTP tensor is found.
    """
    model_dir = Path(model_path)

    # A module without ``config`` simply has no inlined-MTP prefixes.
    inlined_prefixes = set(get_inlined_mtp_prefixes(getattr(model, "config", None)))
    # Trailing dot so that "model.layers.7" does not also match "model.layers.78".
    inlined_tuple = tuple(p + "." for p in inlined_prefixes)

    # Combined predicate covering both conventions in one pass.
    def predicate(key: str) -> bool:
        return key.startswith(inlined_tuple) or "mtp" in key

    tensors = _load_tensors_matching(model_dir, predicate)
    if not tensors:
        return [], {}

    # Only separate-file keys may go through _keys_to_prefixes; an inlined key
    # would contribute its top-level segment ("model") as a prefix.
    separate_keys = [k for k in tensors if not k.startswith(inlined_tuple)]
    prefixes = inlined_prefixes | _keys_to_prefixes(separate_keys)

    not_in_state_dict = _apply_to_model_state_dict(model, tensors)

    print(
        f"✓ Detected {len(tensors)} MTP tensors under {sorted(prefixes)} "
        f"(loaded into model: {len(tensors) - len(not_in_state_dict)}, "
        f"orphaned: {len(not_in_state_dict)})"
    )

    return sorted(prefixes), not_in_state_dict


def get_dtype(dtype):
Expand Down
30 changes: 30 additions & 0 deletions tests/_test_utils/examples/llm_ptq_example_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Re-export ``examples/llm_ptq/example_utils`` so tests can import it via
``from _test_utils.examples.llm_ptq_example_utils import example_utils``
without per-file ``sys.path`` shims.
"""

import sys

from _test_utils.examples.run_command import MODELOPT_ROOT

# Directory that holds the ``example_utils`` module re-exported below.
_LLM_PTQ_DIR = MODELOPT_ROOT / "examples" / "llm_ptq"
# Insert only when absent so repeated imports of this shim add no duplicates.
if str(_LLM_PTQ_DIR) not in sys.path:
    sys.path.insert(0, str(_LLM_PTQ_DIR))

# Keep this import below the sys.path tweak: ``example_utils`` is expected to
# be found through the directory inserted above.
import example_utils

__all__ = ["example_utils"]
Loading
Loading