diff --git a/CHANGELOG.rst b/CHANGELOG.rst index be2210a33f2..cf1ce5d1927 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,8 @@ Changelog - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. - Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection `_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback. - The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically. +- ``examples/llm_ptq/hf_ptq.py`` now derives its ``--qformat`` and ``--kv_cache_qformat`` CLI vocabularies by discovering the YAML presets under ``modelopt_recipes/configs/ptq/presets/{model,kv}/`` rather than carrying a hardcoded ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` table. Adding a new preset YAML makes it available on the CLI with no script change. All previously-supported short names (``int8_sq``, ``nvfp4_awq``, ``fp8_pb_wo``, ``nvfp4_mse``, ``w4a8_awq``, ``nvfp4_local_hessian``, ``fp8_pc_pt``, ``int8_wo``) keep working via a small deprecation alias table; new formats should be exposed as preset YAMLs (or, longer term, as full ``--recipe`` recipes). +- Add ``configs/ptq/presets/kv/fp8_cast.yaml`` and ``configs/ptq/presets/kv/nvfp4_cast.yaml``, promoting ``fp8_cast`` / ``nvfp4_cast`` to first-class KV presets composed from the existing ``kv_fp8_cast`` / ``kv_nvfp4_cast`` unit fragments. The previous runtime ``use_constant_amax`` post-edit in ``hf_ptq.py`` is removed; ``use_constant_amax: true`` now lives in the YAML and is therefore authoritative. **Custom (out-of-tree) recipes that target a cast KV format must set ``use_constant_amax: true`` themselves on the ``[kv]_bmm_quantizer`` config** — in-tree recipes already do via the ``kv_*_cast`` units. **Bug Fixes** diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index fb7f3d20fd9..9403cd321fa 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -18,6 +18,7 @@ import random import time import warnings +from collections.abc import Iterator, Mapping from pathlib import Path from typing import Any @@ -55,7 +56,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq import modelopt.torch.sparsity as mts -from modelopt.recipe import ModelOptPTQRecipe, load_recipe +from modelopt.recipe import ModelOptPTQRecipe, load_config, load_recipe from modelopt.torch.export import ( export_hf_checkpoint, export_hf_vllm_fq_checkpoint, @@ -66,7 +67,12 @@ save_expert_token_count_table, ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model -from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.opt.config_loader import BUILTIN_CONFIG_ROOT +from modelopt.torch.quantization.config import ( + QuantizeConfig, + _default_disabled_quantizer_cfg, + need_calibration, +) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized from modelopt.torch.speculative.eagle.utils import ( @@ -86,56 +92,167 @@ RAND_SEED = 1234 -def _set_kv_cache_constant_amax(quant_cfg: list) -> None: - """Set use_constant_amax on KV cache quantizers. +# Preset directories under modelopt_recipes/ that back the --qformat and +# --kv_cache_qformat CLI vocabularies. Each ``*.yaml`` file in these directories is +# automatically discovered and exposed as a valid CLI value via _PresetCfgChoices, +# so no code change in this script is required when a YAML is added or removed. +# This is deliberate: every preset YAML is CLI-exposed, there is no separate +# allow-list — the directory listing is the policy. +# +# That said, prefer NOT to add new YAMLs to these preset directories either. The +# long-term direction is to retire --qformat / --kv_cache_qformat entirely in favour +# of --recipe, which accepts a full PTQ recipe (see modelopt_recipes/general/ptq/ +# and modelopt/recipe/). New quantization configurations should be authored as +# recipes, not as preset entries. +_QFORMAT_PRESET_DIR = "configs/ptq/presets/model" +_KV_QFORMAT_PRESET_DIR = "configs/ptq/presets/kv" + +# Backward-compat short names → canonical preset basename. These aliases predate the +# YAML-driven discovery below and remain accepted so existing scripts keep working. +# +# DO NOT add new entries here. New quantization formats must be exposed via their YAML +# basename under modelopt_recipes/configs/ptq/presets/model/ — the directory listing is +# the canonical CLI vocabulary. This table exists solely to keep pre-existing short +# names (and the scripts/docs that hardcode them) working through deprecation, and +# should only ever shrink. +_QFORMAT_ALIASES: dict[str, str] = { + "int8_sq": "int8_smoothquant", + "int8_wo": "int8_weight_only", + "w4a8_awq": "w4a8_awq_beta", + "nvfp4_awq": "nvfp4_awq_lite", + "nvfp4_mse": "nvfp4_w4a4_weight_mse_fp8_sweep", + "nvfp4_local_hessian": "nvfp4_w4a4_weight_local_hessian", + "fp8_pb_wo": "fp8_2d_blockwise_weight_only", + "fp8_pc_pt": "fp8_per_channel_per_token", +} + +# Sentinel value for ``--kv_cache_qformat`` meaning "no KV cache quantization". +_KV_NONE = "none" - Creates a new dict for the KV bmm quantizer config to avoid mutating shared references. + +def _kv_cfg_uses_constant_amax(kv_quant_cfg: list[dict[str, Any]]) -> bool: + """Return True if this KV cfg pins ``use_constant_amax`` on the bmm quantizer. + + Cast-style KV presets (e.g. ``fp8_cast`` / ``nvfp4_cast``) set + ``use_constant_amax: true`` on the ``*[kv]_bmm_quantizer`` entry; that flag + means there is no data-driven calibration to run, so callers should skip + the KV-only calibration pass. Detect the property from the YAML contents + rather than from the preset name so new cast-style presets work + automatically. """ - for i, entry in enumerate(quant_cfg): + for entry in kv_quant_cfg: if entry.get("quantizer_name") != "*[kv]_bmm_quantizer": continue cfg = entry.get("cfg") or {} - assert isinstance(cfg, dict) - quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}} - break - - -QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { - "int8": mtq.INT8_DEFAULT_CFG, - "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, - "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG, - "fp8": mtq.FP8_DEFAULT_CFG, - "int4_awq": mtq.INT4_AWQ_CFG, - "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, - "nvfp4": mtq.NVFP4_DEFAULT_CFG, - "nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG, - "nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG, - "fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, - "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG, - "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, - "w4a16_nvfp4": mtq.W4A16_NVFP4_CFG, - "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, - "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, - "nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG, - "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, - "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, - "mxfp8": mtq.MXFP8_DEFAULT_CFG, - "nvfp4_local_hessian": mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG, -} + return bool(cfg.get("use_constant_amax")) + return False -KV_QUANT_CFG_CHOICES = { - "none": "none", - "fp8_cast": "FP8_KV_CFG", - "fp8": "FP8_KV_CFG", - "fp8_affine": "FP8_AFFINE_KV_CFG", - "nvfp4_cast": "NVFP4_KV_CFG", - "nvfp4": "NVFP4_KV_CFG", - "nvfp4_affine": "NVFP4_AFFINE_KV_CFG", - "nvfp4_rotate": "NVFP4_KV_ROTATE_CFG", -} -# Formats that use use_constant_amax (no calibration needed). -_KV_CAST_FORMATS = {"fp8_cast", "nvfp4_cast"} +class _PresetCfgChoices(Mapping[str, dict[str, Any]]): + """Lazy mapping of qformat names → quant_cfg dicts loaded from preset YAMLs. + + Iterates the YAML files in ``modelopt_recipes//`` to populate the set + of available qformat names; the supplied ``aliases`` table maps additional + short names onto canonical preset basenames. Loading happens on first access + and is memoised so repeated lookups are cheap. + """ + + def __init__(self, subdir: str, aliases: Mapping[str, str] | None = None): + self._subdir = subdir + self._aliases: dict[str, str] = dict(aliases or {}) + self._presets: set[str] = set() + for entry in BUILTIN_CONFIG_ROOT.joinpath(subdir).iterdir(): + name = entry.name + if name.endswith((".yaml", ".yml")): + self._presets.add(name.rsplit(".", 1)[0]) + # Aliases that point at non-existent presets would silently fail at access + # time; surface this at import instead. + for alias, target in self._aliases.items(): + if target not in self._presets: + raise ValueError( + f"Alias {alias!r} points at preset {target!r} which is not present " + f"under modelopt_recipes/{subdir}/." + ) + self._cache: dict[str, dict[str, Any]] = {} + + def _canonical(self, key: str) -> str | None: + if key in self._presets: + return key + return self._aliases.get(key) + + def __contains__(self, key: object) -> bool: + return isinstance(key, str) and self._canonical(key) is not None + + def __getitem__(self, key: str) -> dict[str, Any]: + canon = self._canonical(key) + if canon is None: + raise KeyError(key) + if canon not in self._cache: + self._cache[canon] = load_config( + f"{self._subdir}/{canon}", schema_type=QuantizeConfig + ).model_dump(exclude_unset=True) + # Deepcopy on retrieval so callers can freely mutate the returned config + # (append per-model overrides, etc.) without poisoning the cached entry. + return copy.deepcopy(self._cache[canon]) + + def __iter__(self) -> Iterator[str]: + yield from sorted(self._presets | set(self._aliases)) + + def __len__(self) -> int: + return len(self._presets) + len(self._aliases) + + +QUANT_CFG_CHOICES: Mapping[str, dict[str, Any]] = _PresetCfgChoices( + _QFORMAT_PRESET_DIR, _QFORMAT_ALIASES +) +KV_QUANT_CFG_CHOICES: Mapping[str, dict[str, Any]] = _PresetCfgChoices(_KV_QFORMAT_PRESET_DIR) + +# Guard against a future ``none.yaml`` (or alias) colliding with the disable sentinel: +# argparse would silently allow both, but the runtime branch on ``!= _KV_NONE`` would +# become ambiguous and the user couldn't reach the real preset. +assert _KV_NONE not in KV_QUANT_CFG_CHOICES, ( + f"_KV_NONE sentinel {_KV_NONE!r} collides with a KV preset; rename the preset." +) + +# Formats supported by mtq.auto_quantize unified-checkpoint export. +# +# This stays hardcoded — and intentionally not derived from the preset directory — +# because auto_quantize compatibility is a property of the export path (the unified +# HF checkpoint writer, TRT-LLM consumer constraints, layer-wise mixing rules), not +# of the YAML itself. A preset can exist and be valid for plain PTQ while not being +# safe to mix into an auto_quantize search. Update this set when adding/removing a +# format from auto_quantize support. +_AUTO_QUANTIZE_QFORMATS: frozenset[str] = frozenset( + { + "fp8", + "int8_smoothquant", + "int8_weight_only", + "int4_awq", + "nvfp4", + "nvfp4_awq_lite", + "nvfp4_w4a4_weight_mse_fp8_sweep", + "w4a8_awq_beta", + "fp8_2d_blockwise_weight_only", + "w4a8_mxfp4_fp8", + "nvfp4_mlp_only", + "nvfp4_experts_only", + "nvfp4_omlp_only", + "nvfp4_w4a4_weight_local_hessian", + "mxfp8", + } +) + + +def _canonical_qformat(name: str) -> str: + """Resolve a user-provided qformat token to its canonical preset basename. + + Lets membership checks (e.g. against :data:`_AUTO_QUANTIZE_QFORMATS`) accept + either the short alias (``int8_sq``) or the canonical YAML basename + (``int8_smoothquant``). Unknown tokens pass through unchanged so the existing + error paths still fire. + """ + return _QFORMAT_ALIASES.get(name, name) + mto.enable_huggingface_checkpointing() @@ -311,27 +428,11 @@ def auto_quantize( qformat_list = args.qformat.split(",") assert qformat_list, "No quantization formats provided" - # Check if all provided quantization formats are supported + # Check if all provided quantization formats are supported. Canonicalize first so + # callers may pass either the short alias (``int8_sq``) or the canonical YAML + # basename (``int8_smoothquant``). assert all( - qformat - in [ - "fp8", - "int8_sq", - "int8_wo", - "int4_awq", - "nvfp4", - "nvfp4_awq", - "nvfp4_mse", - "w4a8_awq", - "fp8_pb_wo", - "w4a8_mxfp4_fp8", - "nvfp4_mlp_only", - "nvfp4_experts_only", - "nvfp4_omlp_only", - "nvfp4_local_hessian", - "mxfp8", - ] - for qformat in qformat_list + _canonical_qformat(qformat) in _AUTO_QUANTIZE_QFORMATS for qformat in qformat_list ), "One or more quantization formats provided are not supported for unified checkpoint export" # When language_model is a base text model without lm_head (e.g. Gemma4TextModel), @@ -408,21 +509,16 @@ def forward_step(model, batch): calibrate_loop = create_forward_loop(dataloader=calib_dataloader) # We need to explicitly set up KV cache quantization after auto_quantize - enable_quant_kv_cache = args.kv_cache_qformat != "none" + enable_quant_kv_cache = args.kv_cache_qformat != _KV_NONE print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") if enable_quant_kv_cache: - kv_cache_quant_cfg = copy.deepcopy( - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] - ) + kv_cache_quant_cfg = copy.deepcopy(KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"]) kv_cache_quant_cfg = [ e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*" ] # keep other quantizers from auto_quantize - if args.kv_cache_qformat in _KV_CAST_FORMATS: - _set_kv_cache_constant_amax(kv_cache_quant_cfg) - mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg) - if args.kv_cache_qformat not in _KV_CAST_FORMATS: + if not _kv_cfg_uses_constant_amax(kv_cache_quant_cfg): # Calibrate only the KV cache quantizers; disable all others. with mtq.set_quantizer_by_cfg_context( language_model, @@ -446,21 +542,14 @@ def load_model(args: argparse.Namespace): ) else: assert args.qformat in QUANT_CFG_CHOICES, ( - f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}" + f"Quantization format is not supported for low memory mode. Supported formats: {list(QUANT_CFG_CHOICES)}" ) quant_cfg = QUANT_CFG_CHOICES[args.qformat] - if args.kv_cache_qformat != "none": + if args.kv_cache_qformat != _KV_NONE: quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( quant_cfg, - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], + KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"], ) - # Mirror the use_constant_amax logic from quantize_main so that init_quantized_weights - # builds the KV quantizers with use_constant_amax already set. In calibration_only mode - # mtq.calibrate() does not re-apply quant_cfg, so this must happen before - # init_quantized_weights runs. - if args.kv_cache_qformat in _KV_CAST_FORMATS: - quant_cfg = copy.deepcopy(quant_cfg) - _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) # Do not use real quant GEMM so the calibration can be more accurate. with init_quantized_weights( @@ -1110,7 +1199,7 @@ def _is_layerwise(obj): ) assert args.qformat in QUANT_CFG_CHOICES, ( - f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES.keys())}" + f"Unsupported quantization format: {args.qformat}, choices are: {list(QUANT_CFG_CHOICES)}" ) quant_cfg = QUANT_CFG_CHOICES[args.qformat] @@ -1122,14 +1211,14 @@ def _is_layerwise(obj): args.moe_calib_experts_ratio, ) - enable_quant_kv_cache = args.kv_cache_qformat != "none" + enable_quant_kv_cache = args.kv_cache_qformat != _KV_NONE print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer. if enable_quant_kv_cache: quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant( quant_cfg, - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"], + KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]["quant_cfg"], ) # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92) @@ -1142,14 +1231,6 @@ def _is_layerwise(obj): quant_cfg["quant_cfg"].append({"quantizer_name": pattern, "enable": False}) print(f"Excluding MTP layer from quantization: {pattern}") - # Use constant amax for KV quantizers when a cast format is selected. - # Recipes are authoritative for KV cache config (including use_constant_amax), - # so skip this post-hoc override when --recipe is used; rely on the YAML instead - # (see modelopt_recipes/general/ptq/*_cast_kv.yaml). - if args.recipe is None and args.kv_cache_qformat in _KV_CAST_FORMATS: - quant_cfg = copy.deepcopy(quant_cfg) - _set_kv_cache_constant_amax(quant_cfg["quant_cfg"]) - if needs_checkpoint_path_update(quant_cfg): quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path) print( @@ -1300,7 +1381,7 @@ def parse_args() -> argparse.Namespace: "--kv_cache_qformat", required=False, default="fp8_cast", - choices=KV_QUANT_CFG_CHOICES.keys(), + choices=[_KV_NONE, *KV_QUANT_CFG_CHOICES], help=( "Specify KV cache quantization format. Default: fp8_cast. " "Formats ending in '_cast' (fp8_cast, nvfp4_cast) set the amax to FP8 range " diff --git a/modelopt_recipes/configs/ptq/presets/kv/fp8_cast.yaml b/modelopt_recipes/configs/ptq/presets/kv/fp8_cast.yaml new file mode 100644 index 00000000000..e689a17ad4e --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/kv/fp8_cast.yaml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Partial QuantizeConfig that enables FP8 E4M3 KV-cache quantizers with +# ``use_constant_amax`` (no data-driven calibration required). + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + kv_fp8_cast: configs/ptq/units/kv_fp8_cast + +quant_cfg: + - $import: kv_fp8_cast diff --git a/modelopt_recipes/configs/ptq/presets/kv/nvfp4_cast.yaml b/modelopt_recipes/configs/ptq/presets/kv/nvfp4_cast.yaml new file mode 100644 index 00000000000..665e20fe4fa --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/kv/nvfp4_cast.yaml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Partial QuantizeConfig that enables NVFP4 KV-cache quantizers with +# ``use_constant_amax`` (no data-driven calibration required). + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + kv_nvfp4_cast: configs/ptq/units/kv_nvfp4_cast + +quant_cfg: + - $import: kv_nvfp4_cast