Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Changelog
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
- Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection <https://huggingface.co/collections/nvidia/nemotron-post-training-v3>`_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback.
- The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically.
- ``examples/llm_ptq/hf_ptq.py`` now derives its ``--qformat`` and ``--kv_cache_qformat`` CLI vocabularies by discovering the YAML presets under ``modelopt_recipes/configs/ptq/presets/{model,kv}/`` rather than carrying a hardcoded ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` table. Adding a new preset YAML makes it available on the CLI with no script change. All previously-supported short names (``int8_sq``, ``nvfp4_awq``, ``fp8_pb_wo``, ``nvfp4_mse``, ``w4a8_awq``, ``nvfp4_local_hessian``, ``fp8_pc_pt``, ``int8_wo``) keep working via a small deprecation alias table; new formats should be exposed as preset YAMLs (or, longer term, as full ``--recipe`` recipes).
- Add ``configs/ptq/presets/kv/fp8_cast.yaml`` and ``configs/ptq/presets/kv/nvfp4_cast.yaml``, promoting ``fp8_cast`` / ``nvfp4_cast`` to first-class KV presets composed from the existing ``kv_fp8_cast`` / ``kv_nvfp4_cast`` unit fragments. The previous runtime ``use_constant_amax`` post-edit in ``hf_ptq.py`` is removed; ``use_constant_amax: true`` now lives in the YAML and is therefore authoritative. **Custom (out-of-tree) recipes that target a cast KV format must set ``use_constant_amax: true`` themselves on the ``[kv]_bmm_quantizer`` config** — in-tree recipes already do via the ``kv_*_cast`` units.

**Bug Fixes**

Expand Down
271 changes: 176 additions & 95 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import random
import time
import warnings
from collections.abc import Iterator, Mapping
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -55,7 +56,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
from modelopt.recipe import ModelOptPTQRecipe, load_config, load_recipe
from modelopt.torch.export import (
export_hf_checkpoint,
export_hf_vllm_fq_checkpoint,
Expand All @@ -66,7 +67,12 @@
save_expert_token_count_table,
)
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
from modelopt.torch.opt.config_loader import BUILTIN_CONFIG_ROOT
from modelopt.torch.quantization.config import (
QuantizeConfig,
_default_disabled_quantizer_cfg,
need_calibration,
)
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.speculative.eagle.utils import (
Expand All @@ -86,56 +92,167 @@
RAND_SEED = 1234


def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"""Set use_constant_amax on KV cache quantizers.
# Preset directories under modelopt_recipes/ that back the --qformat and
# --kv_cache_qformat CLI vocabularies. Each ``*.yaml`` file in these directories is
# automatically discovered and exposed as a valid CLI value via _PresetCfgChoices,
# so no code change in this script is required when a YAML is added or removed.
# This is deliberate: every preset YAML is CLI-exposed, there is no separate
# allow-list — the directory listing is the policy.
#
# That said, prefer NOT to add new YAMLs to these preset directories either. The
# long-term direction is to retire --qformat / --kv_cache_qformat entirely in favour
# of --recipe, which accepts a full PTQ recipe (see modelopt_recipes/general/ptq/
# and modelopt/recipe/). New quantization configurations should be authored as
# recipes, not as preset entries.
_QFORMAT_PRESET_DIR = "configs/ptq/presets/model"
_KV_QFORMAT_PRESET_DIR = "configs/ptq/presets/kv"

# Backward-compat short names → canonical preset basename. These aliases predate the
# YAML-driven discovery below and remain accepted so existing scripts keep working.
#
# DO NOT add new entries here. New quantization formats must be exposed via their YAML
# basename under modelopt_recipes/configs/ptq/presets/model/ — the directory listing is
# the canonical CLI vocabulary. This table exists solely to keep pre-existing short
# names (and the scripts/docs that hardcode them) working through deprecation, and
# should only ever shrink.
_QFORMAT_ALIASES: dict[str, str] = {
"int8_sq": "int8_smoothquant",
"int8_wo": "int8_weight_only",
"w4a8_awq": "w4a8_awq_beta",
"nvfp4_awq": "nvfp4_awq_lite",
"nvfp4_mse": "nvfp4_w4a4_weight_mse_fp8_sweep",
"nvfp4_local_hessian": "nvfp4_w4a4_weight_local_hessian",
"fp8_pb_wo": "fp8_2d_blockwise_weight_only",
"fp8_pc_pt": "fp8_per_channel_per_token",
}

# Sentinel value for ``--kv_cache_qformat`` meaning "no KV cache quantization".
_KV_NONE = "none"

Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.

def _kv_cfg_uses_constant_amax(kv_quant_cfg: list[dict[str, Any]]) -> bool:
"""Return True if this KV cfg pins ``use_constant_amax`` on the bmm quantizer.

Cast-style KV presets (e.g. ``fp8_cast`` / ``nvfp4_cast``) set
``use_constant_amax: true`` on the ``*[kv]_bmm_quantizer`` entry; that flag
means there is no data-driven calibration to run, so callers should skip
the KV-only calibration pass. Detect the property from the YAML contents
rather than from the preset name so new cast-style presets work
automatically.
"""
for i, entry in enumerate(quant_cfg):
for entry in kv_quant_cfg:
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
continue
cfg = entry.get("cfg") or {}
assert isinstance(cfg, dict)
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
break


QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a16_nvfp4": mtq.W4A16_NVFP4_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
"nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
"nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
}
return bool(cfg.get("use_constant_amax"))
return False

KV_QUANT_CFG_CHOICES = {
"none": "none",
"fp8_cast": "FP8_KV_CFG",
"fp8": "FP8_KV_CFG",
"fp8_affine": "FP8_AFFINE_KV_CFG",
"nvfp4_cast": "NVFP4_KV_CFG",
"nvfp4": "NVFP4_KV_CFG",
"nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
"nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
}

# Formats that use use_constant_amax (no calibration needed).
_KV_CAST_FORMATS = {"fp8_cast", "nvfp4_cast"}
class _PresetCfgChoices(Mapping[str, dict[str, Any]]):
"""Lazy mapping of qformat names → quant_cfg dicts loaded from preset YAMLs.

Iterates the YAML files in ``modelopt_recipes/<subdir>/`` to populate the set
of available qformat names; the supplied ``aliases`` table maps additional
short names onto canonical preset basenames. Loading happens on first access
and is memoised so repeated lookups are cheap.
"""

def __init__(self, subdir: str, aliases: Mapping[str, str] | None = None):
self._subdir = subdir
self._aliases: dict[str, str] = dict(aliases or {})
self._presets: set[str] = set()
for entry in BUILTIN_CONFIG_ROOT.joinpath(subdir).iterdir():
name = entry.name
if name.endswith((".yaml", ".yml")):
self._presets.add(name.rsplit(".", 1)[0])
# Aliases that point at non-existent presets would silently fail at access
# time; surface this at import instead.
for alias, target in self._aliases.items():
if target not in self._presets:
raise ValueError(
f"Alias {alias!r} points at preset {target!r} which is not present "
f"under modelopt_recipes/{subdir}/."
)
self._cache: dict[str, dict[str, Any]] = {}

def _canonical(self, key: str) -> str | None:
if key in self._presets:
return key
return self._aliases.get(key)

def __contains__(self, key: object) -> bool:
return isinstance(key, str) and self._canonical(key) is not None

def __getitem__(self, key: str) -> dict[str, Any]:
canon = self._canonical(key)
if canon is None:
raise KeyError(key)
if canon not in self._cache:
self._cache[canon] = load_config(
f"{self._subdir}/{canon}", schema_type=QuantizeConfig
).model_dump(exclude_unset=True)
# Deepcopy on retrieval so callers can freely mutate the returned config
# (append per-model overrides, etc.) without poisoning the cached entry.
return copy.deepcopy(self._cache[canon])

def __iter__(self) -> Iterator[str]:
yield from sorted(self._presets | set(self._aliases))

def __len__(self) -> int:
return len(self._presets) + len(self._aliases)


QUANT_CFG_CHOICES: Mapping[str, dict[str, Any]] = _PresetCfgChoices(
_QFORMAT_PRESET_DIR, _QFORMAT_ALIASES
)
KV_QUANT_CFG_CHOICES: Mapping[str, dict[str, Any]] = _PresetCfgChoices(_KV_QFORMAT_PRESET_DIR)

# Guard against a future ``none.yaml`` (or alias) colliding with the disable sentinel:
# argparse would silently allow both, but the runtime branch on ``!= _KV_NONE`` would
# become ambiguous and the user couldn't reach the real preset.
assert _KV_NONE not in KV_QUANT_CFG_CHOICES, (
f"_KV_NONE sentinel {_KV_NONE!r} collides with a KV preset; rename the preset."
)

# Formats supported by mtq.auto_quantize unified-checkpoint export.
#
# This stays hardcoded — and intentionally not derived from the preset directory —
# because auto_quantize compatibility is a property of the export path (the unified
# HF checkpoint writer, TRT-LLM consumer constraints, layer-wise mixing rules), not
# of the YAML itself. A preset can exist and be valid for plain PTQ while not being
# safe to mix into an auto_quantize search. Update this set when adding/removing a
# format from auto_quantize support.
_AUTO_QUANTIZE_QFORMATS: frozenset[str] = frozenset(
{
"fp8",
"int8_smoothquant",
"int8_weight_only",
"int4_awq",
"nvfp4",
"nvfp4_awq_lite",
"nvfp4_w4a4_weight_mse_fp8_sweep",
"w4a8_awq_beta",
"fp8_2d_blockwise_weight_only",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_w4a4_weight_local_hessian",
"mxfp8",
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
)


def _canonical_qformat(name: str) -> str:
"""Resolve a user-provided qformat token to its canonical preset basename.

Lets membership checks (e.g. against :data:`_AUTO_QUANTIZE_QFORMATS`) accept
either the short alias (``int8_sq``) or the canonical YAML basename
(``int8_smoothquant``). Unknown tokens pass through unchanged so the existing
error paths still fire.
"""
return _QFORMAT_ALIASES.get(name, name)


mto.enable_huggingface_checkpointing()

Expand Down Expand Up @@ -311,27 +428,11 @@ def auto_quantize(

qformat_list = args.qformat.split(",")
assert qformat_list, "No quantization formats provided"
# Check if all provided quantization formats are supported
# Check if all provided quantization formats are supported. Canonicalize first so
# callers may pass either the short alias (``int8_sq``) or the canonical YAML
# basename (``int8_smoothquant``).
assert all(
qformat
in [
"fp8",
"int8_sq",
"int8_wo",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"nvfp4_mse",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_local_hessian",
"mxfp8",
]
for qformat in qformat_list
_canonical_qformat(qformat) in _AUTO_QUANTIZE_QFORMATS for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

# When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
Expand Down Expand Up @@ -408,21 +509,16 @@ def forward_step(model, batch):

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
# We need to explicitly set up KV cache quantization after auto_quantize
enable_quant_kv_cache = args.kv_cache_qformat != "none"
enable_quant_kv_cache = args.kv_cache_qformat != _KV_NONE
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
if enable_quant_kv_cache:
kv_cache_quant_cfg = copy.deepcopy(
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
)
kv_cache_quant_cfg = copy.deepcopy(KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"])
kv_cache_quant_cfg = [
e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*"
] # keep other quantizers from auto_quantize

if args.kv_cache_qformat in _KV_CAST_FORMATS:
_set_kv_cache_constant_amax(kv_cache_quant_cfg)

mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
if args.kv_cache_qformat not in _KV_CAST_FORMATS:
if not _kv_cfg_uses_constant_amax(kv_cache_quant_cfg):
# Calibrate only the KV cache quantizers; disable all others.
with mtq.set_quantizer_by_cfg_context(
language_model,
Expand All @@ -446,21 +542,14 @@ def load_model(args: argparse.Namespace):
)
else:
assert args.qformat in QUANT_CFG_CHOICES, (
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
f"Quantization format is not supported for low memory mode. Supported formats: {list(QUANT_CFG_CHOICES)}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
if args.kv_cache_qformat != "none":
if args.kv_cache_qformat != _KV_NONE:
quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"],
)
# Mirror the use_constant_amax logic from quantize_main so that init_quantized_weights
# builds the KV quantizers with use_constant_amax already set. In calibration_only mode
# mtq.calibrate() does not re-apply quant_cfg, so this must happen before
# init_quantized_weights runs.
if args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

# Do not use real quant GEMM so the calibration can be more accurate.
with init_quantized_weights(
Expand Down Expand Up @@ -1110,7 +1199,7 @@ def _is_layerwise(obj):
)

assert args.qformat in QUANT_CFG_CHOICES, (
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}"
f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES)}"
)
quant_cfg = QUANT_CFG_CHOICES[args.qformat]

Expand All @@ -1122,14 +1211,14 @@ def _is_layerwise(obj):
args.moe_calib_experts_ratio,
)

enable_quant_kv_cache = args.kv_cache_qformat != "none"
enable_quant_kv_cache = args.kv_cache_qformat != _KV_NONE
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")

# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"],
)

# Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
Expand All @@ -1142,14 +1231,6 @@ def _is_layerwise(obj):
quant_cfg["quant_cfg"].append({"quantizer_name": pattern, "enable": False})
print(f"Excluding MTP layer from quantization: {pattern}")

# Use constant amax for KV quantizers when a cast format is selected.
# Recipes are authoritative for KV cache config (including use_constant_amax),
# so skip this post-hoc override when --recipe is used; rely on the YAML instead
# (see modelopt_recipes/general/ptq/*_cast_kv.yaml).
if args.recipe is None and args.kv_cache_qformat in _KV_CAST_FORMATS:
quant_cfg = copy.deepcopy(quant_cfg)
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])

if needs_checkpoint_path_update(quant_cfg):
quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
print(
Expand Down Expand Up @@ -1300,7 +1381,7 @@ def parse_args() -> argparse.Namespace:
"--kv_cache_qformat",
required=False,
default="fp8_cast",
choices=KV_QUANT_CFG_CHOICES.keys(),
choices=[_KV_NONE, *KV_QUANT_CFG_CHOICES],
help=(
"Specify KV cache quantization format. Default: fp8_cast. "
"Formats ending in '_cast' (fp8_cast, nvfp4_cast) set the amax to FP8 range "
Expand Down
Loading
Loading