Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 140 additions & 74 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
import shutil
import sys
import warnings
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any

import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from safetensors.torch import load_file
from safetensors import safe_open
from transformers import (
AutoConfig,
AutoModel,
Expand Down Expand Up @@ -315,100 +316,165 @@ def get_processor(
return None


def get_inlined_mtp_prefixes(config: Any) -> list[str]:
    """Return the state-dict key prefixes of inlined MTP layers for an HF config.

    Pure function: it only reads attributes from ``config``.

    Inlined-MTP convention (DeepSeek-V3, GLM-5.1 ``GlmMoeDsa``, GLM-4.7):
    MTP tensors live under ``model.layers[i]`` for
    ``i in [num_hidden, num_hidden + num_nextn_predict_layers)``.

    Args:
        config: Object exposing ``num_hidden_layers`` and, optionally,
            ``num_nextn_predict_layers``.

    Returns:
        One ``"model.layers.<i>"`` prefix per MTP layer, in ascending index
        order, or ``[]`` when the config does not declare any MTP layers.
    """
    # ``or 0`` guards against configs that set ``num_nextn_predict_layers``
    # explicitly to ``None`` rather than omitting the field.
    num_nextn = int(getattr(config, "num_nextn_predict_layers", 0) or 0)
    if not num_nextn:
        # ``num_hidden_layers`` is deliberately not read on this path.
        return []

    first_mtp_layer = config.num_hidden_layers
    return [
        f"model.layers.{layer_idx}"
        for layer_idx in range(first_mtp_layer, first_mtp_layer + num_nextn)
    ]


def _keys_to_prefixes(keys: Iterable[str]) -> set[str]:
    """Map separate-file MTP keys to state-dict prefixes for ``exclude_modules``.

    Pure function. For each key, extracts:
      - the top-level module prefix (e.g. ``"mtp"`` from ``"mtp.fc.weight"``)
        so non-layer MTP weights like ``mtp.fc`` and ``mtp.norm`` are excluded.
      - the specific layer prefix (e.g. ``"mtp.layers.0"`` from
        ``"mtp.layers.0.q_proj.weight"``), taken from the first
        ``layers.<digits>`` pair in the key.

    The caller must filter inlined keys out first: an inlined key such as
    ``"model.layers.78.eh_proj.weight"`` would contribute ``"model"`` as its
    top-level prefix.

    Args:
        keys: Dotted state-dict keys that follow the separate-file convention.

    Returns:
        The set of non-empty prefixes derived from ``keys``; empty when
        ``keys`` is empty.
    """
    prefixes: set[str] = set()
    for key in keys:
        parts = key.split(".")
        # ``str.split`` never returns an empty list, so test the first segment
        # itself: an empty key (or one starting with ".") must not contribute
        # the empty string as a prefix.
        if parts[0]:
            prefixes.add(parts[0])
        # Only the first ``layers.<digits>`` pair counts; the last segment can
        # never be followed by an index, hence ``parts[:-1]``.
        for i, part in enumerate(parts[:-1]):
            if part == "layers" and parts[i + 1].isdigit():
                prefixes.add(".".join(parts[: i + 2]))
                break
    return prefixes


def _load_tensors_matching(
model_dir: Path, predicate: Callable[[str, str | None], bool]
) -> dict[str, torch.Tensor]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Minor: _keys_to_prefixes is only correct for separate-file keys — if fed an inlined key like "model.layers.78.eh_proj.weight" it would emit "model" as a top-level prefix, which is not what the exporter wants. The current caller filters via separate_keys = [k for k in tensors if not k.startswith(inlined_tuple)] so this is safe today; worth a one-line note in the docstring ("caller must filter inlined keys first") to prevent future misuse.

"""Stream tensors satisfying ``predicate`` from every safetensors source
in ``model_dir`` via ``safe_open``.

Sources walked (each at most once):
1. Sharded layout: shards referenced by ``model.safetensors.index.json``.
The predicate sees ``(key, shard_filename)``.
2. Standalone files: any ``*.safetensors`` not referenced by the index
(including a standalone ``model.safetensors`` when no index exists and
legacy auxiliary files like ``mtp.safetensors``). The predicate sees
``(key, file_name)``.

Returns an empty dict if no tensor matches.
"""
tensors: dict[str, torch.Tensor] = {}
seen_shards: set[str] = set()

index_file = model_dir / "model.safetensors.index.json"
if index_file.exists():
with open(index_file) as f:
weight_map = json.load(f)["weight_map"]
per_shard: dict[str, list[str]] = {}
for key, shard_name in weight_map.items():
if predicate(key, shard_name):
per_shard.setdefault(shard_name, []).append(key)
for shard_name, keys in per_shard.items():
seen_shards.add(shard_name)
with safe_open(str(model_dir / shard_name), framework="pt", device="cpu") as f:
for k in keys:
tensors[k] = f.get_tensor(k)

for shard in sorted(model_dir.glob("*.safetensors")):
if shard.name in seen_shards:
continue
with safe_open(str(shard), framework="pt", device="cpu") as f:
for k in f.keys(): # noqa: SIM118 - safe_open is not iterable
if predicate(k, shard.name):
tensors[k] = f.get_tensor(k)
return tensors


def _apply_to_model_state_dict(
    model: torch.nn.Module, tensors: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Load the tensors the model has a slot for and return the rest.

    Splits ``tensors`` by whether each key is in ``model.state_dict()``.
    The matching keys are loaded into the model in-place; the remainder
    (orphans) is returned so the exporter can route them through
    ``extra_state_dict``.

    Args:
        model: Module that receives the matching tensors.
        tensors: Candidate tensors keyed by state-dict name.

    Returns:
        The entries of ``tensors`` whose key is not in ``model.state_dict()``.
    """
    model_state = model.state_dict()
    matched: dict[str, torch.Tensor] = {}
    orphans: dict[str, torch.Tensor] = {}
    # Single pass: each tensor lands in exactly one of the two buckets.
    for key, tensor in tensors.items():
        bucket = matched if key in model_state else orphans
        bucket[key] = tensor

    if matched:
        # ``strict=False`` because ``matched`` is only a subset of the
        # model's parameters.
        model.load_state_dict(matched, strict=False)
    return orphans


def load_mtp_weights(
    model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
    """Load MTP weights from the model checkpoint.

    Detects MTP layers under two on-disk conventions:

    1. Separate-file: weights live in a non-standard safetensors file (e.g.,
       ``mtp.safetensors``) with keys prefixed by ``mtp``.
    2. Inlined: weights live in the main shards under
       ``model.layers[num_hidden : num_hidden + num_nextn_predict_layers]``
       (DeepSeek-V3, GLM-5.1 ``GlmMoeDsa``, GLM-4.7).

    Whether the HF model class actually instantiates the MTP module varies:
    DeepSeek-V3's modeling code adds the extra decoder layers, while
    ``GlmMoeDsaModel`` in transformers >=5.7 builds only ``num_hidden`` layers
    and leaves the inlined-MTP keys orphaned. To keep both paths correct, we
    always read the tensors off disk here and split them: any key that maps
    to a parameter in ``model.state_dict()`` is loaded into the model;
    the remainder is returned as ``not_in_state_dict`` for the exporter to
    merge in via ``extra_state_dict``.

    Args:
        model: The loaded model that may be missing weights. Its ``config``
            is read to find inlined MTP layers.
        model_path: Path to the model directory.

    Returns:
        A ``(prefixes, not_in_state_dict)`` tuple:

        - Sorted list of layer prefixes that should be excluded from
          quantization (e.g., ``"mtp.layers.0"`` or ``"model.layers.78"``).
          Empty if no MTP tensors were found on disk.
        - Dictionary of MTP weights that have no slot in the model's state
          dict; ``export_hf_checkpoint`` merges these via ``extra_state_dict``.
    """
    model_dir = Path(model_path)

    inlined_prefixes = set(get_inlined_mtp_prefixes(model.config))
    # Trailing "." so that e.g. "model.layers.7" cannot match "model.layers.78".
    inlined_tuple = tuple(p + "." for p in inlined_prefixes)

    def predicate(key: str, shard_name: str | None) -> bool:
        # Inlined: key prefix matches an MTP layer index from the config.
        if inlined_tuple and key.startswith(inlined_tuple):
            return True
        # Separate-file legacy: ``"mtp"`` in the key or in the shard filename
        # (e.g. ``mtp.safetensors`` referenced or sitting alongside the index).
        return "mtp" in key or (shard_name is not None and "mtp" in shard_name)

    tensors = _load_tensors_matching(model_dir, predicate)

    if tensors:
        # Anything we loaded that isn't covered by an inlined prefix came from
        # the separate-file convention; derive its prefixes from the keys
        # themselves. (``startswith`` with an empty tuple is always False, so
        # every key counts as separate-file when nothing is inlined.)
        separate_keys = [k for k in tensors if not k.startswith(inlined_tuple)]
        prefixes = inlined_prefixes | _keys_to_prefixes(separate_keys)
    else:
        # The config may declare MTP layers that are absent on disk; report
        # nothing in that case.
        prefixes = set()

    not_in_state_dict = _apply_to_model_state_dict(model, tensors)

    if prefixes:
        print(
            f"✓ Detected {len(tensors)} MTP tensors under {sorted(prefixes)} "
            f"(loaded into model: {len(tensors) - len(not_in_state_dict)}, "
            f"orphaned: {len(not_in_state_dict)})"
        )

    return sorted(prefixes), not_in_state_dict


def get_dtype(dtype):
Expand Down
36 changes: 36 additions & 0 deletions tests/_test_utils/examples/llm_ptq_example_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Importer for ``examples/llm_ptq/example_utils.py``.

The module lives next to the example script (not inside the ``modelopt``
package), so we add ``examples/llm_ptq/`` to ``sys.path`` once here and
re-export the module. Tests then import it as::

from _test_utils.examples.llm_ptq_example_utils import example_utils

instead of repeating ``sys.path`` manipulation in every test file.
"""

import sys

from _test_utils.examples.run_command import MODELOPT_ROOT

# Directory that holds ``example_utils.py``; it sits outside the ``modelopt``
# package, so it has to be put on ``sys.path`` before it can be imported.
_LLM_PTQ_DIR = MODELOPT_ROOT / "examples" / "llm_ptq"
if str(_LLM_PTQ_DIR) not in sys.path:
    # The membership check keeps the entry from being inserted more than once.
    sys.path.insert(0, str(_LLM_PTQ_DIR))

# Deliberately imported after the ``sys.path`` update above.
import example_utils

__all__ = ["example_utils"]
Loading
Loading