diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 6d27aa593f..1e0b243c3f 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -55,7 +55,12 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq import modelopt.torch.sparsity as mts -from modelopt.recipe import ModelOptPTQRecipe, load_recipe +from modelopt.recipe import ( + ModelOptAutoQuantizeRecipe, + ModelOptPTQRecipe, + ModelOptRecipeBase, + load_recipe, +) from modelopt.torch.export import ( export_hf_checkpoint, export_hf_vllm_fq_checkpoint, @@ -208,6 +213,7 @@ def make_calib_dataloader( tokenizer: PreTrainedTokenizerBase | None, device: torch.device, model_type: str | None, + recipe: ModelOptRecipeBase | None = None, ) -> tuple[DataLoader | _DeviceDataLoader, str | None]: calib_dataloader = None first_text_speech_dataset = None @@ -271,8 +277,12 @@ def make_calib_dataloader( assert tokenizer is not None and isinstance( tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) ), "The PreTrainedTokenizer must be set" - # Labels are only needed for gradient-based auto_quantize - include_labels = ( + # Labels are only needed for gradient-based auto_quantize (CLI or recipe path). + is_autoquant_recipe_gradient = ( + isinstance(recipe, ModelOptAutoQuantizeRecipe) + and recipe.auto_quantize.method == "gradient" + ) + include_labels = is_autoquant_recipe_gradient or ( args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" ) @@ -292,12 +302,21 @@ def auto_quantize( args: argparse.Namespace, language_model: torch.nn.Module, calib_dataloader: DataLoader, - auto_quantize_method="gradient", - auto_quantize_score_size=128, - auto_quantize_checkpoint=None, full_model: torch.nn.Module | None = None, + *, + auto_quantize_method: str, + auto_quantize_score_size: int, + auto_quantize_checkpoint: str | None, + constraints: dict, + quantization_formats: list[dict], + disabled_layers: list[str], + kv_cache_quant_cfg: dict | None, ): - """Auto search quantization of multiple formats.""" + """Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize, + run KV cache post-step. All knobs are explicit keyword-only args; the + caller (dispatch site in ``quantize_main``) is responsible for resolving + them from either CLI args or a recipe before invoking this function. + """ if args.calib_with_images: raise NotImplementedError( @@ -305,35 +324,10 @@ def auto_quantize( "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images." ) - assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( + assert args.inference_pipeline_parallel <= 1, ( "Auto Quantization is not supported for pipeline parallel size > 1" ) - qformat_list = args.qformat.split(",") - assert qformat_list, "No quantization formats provided" - # Check if all provided quantization formats are supported - assert all( - qformat - in [ - "fp8", - "int8_sq", - "int8_wo", - "int4_awq", - "nvfp4", - "nvfp4_awq", - "nvfp4_mse", - "w4a8_awq", - "fp8_pb_wo", - "w4a8_mxfp4_fp8", - "nvfp4_mlp_only", - "nvfp4_experts_only", - "nvfp4_omlp_only", - "nvfp4_local_hessian", - "mxfp8", - ] - for qformat in qformat_list - ), "One or more quantization formats provided are not supported for unified checkpoint export" - # When language_model is a base text model without lm_head (e.g. Gemma4TextModel), # use full_model's lm_head to compute logits/loss from hidden states. is_base_model = ( @@ -384,49 +378,42 @@ def forward_step(model, batch): language_model, _ = mtq.auto_quantize( language_model, - constraints={"effective_bits": args.auto_quantize_bits}, + constraints=constraints, data_loader=calib_dataloader, forward_step=forward_step, loss_func=loss_func, # Only used for gradient-based method # TRTLLM only support one quantization format or None (do not quantize, internally supported) - quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list], + quantization_formats=quantization_formats, # type: ignore[arg-type] num_calib_steps=len(calib_dataloader), # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration. num_score_steps=min( len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1) ), verbose=True, - # Disable all default disabled layers such as lm_head, mlp.gate, router etc. - disabled_layers=[ - entry["quantizer_name"] - for entry in _default_disabled_quantizer_cfg - if "parent_class" not in entry - ], + disabled_layers=disabled_layers, method=auto_quantize_method, checkpoint=auto_quantize_checkpoint, ) calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - # We need to explicitly set up KV cache quantization after auto_quantize - enable_quant_kv_cache = args.kv_cache_qformat != "none" - print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization") - if enable_quant_kv_cache: - kv_cache_quant_cfg = copy.deepcopy( - getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"] - ) - kv_cache_quant_cfg = [ - e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*" + print(f"{'Enable' if kv_cache_quant_cfg is not None else 'Disable'} KV cache quantization") + if kv_cache_quant_cfg is not None: + kv_entries = [ + e for e in copy.deepcopy(kv_cache_quant_cfg["quant_cfg"]) if e["quantizer_name"] != "*" ] # keep other quantizers from auto_quantize - if args.kv_cache_qformat in _KV_CAST_FORMATS: - _set_kv_cache_constant_amax(kv_cache_quant_cfg) - - mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg) - if args.kv_cache_qformat not in _KV_CAST_FORMATS: + mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_entries) + # Calibrate only when at least one KV entry doesn't pin amax via use_constant_amax. + # Cast-variant presets (kv_fp8_cast, kv_nvfp4_cast) bake this in; data-driven + # variants (kv_fp8, kv_nvfp4, etc.) need a calibration pass. + needs_calibration = not all( + (e.get("cfg") or {}).get("use_constant_amax") is True for e in kv_entries + ) + if needs_calibration: # Calibrate only the KV cache quantizers; disable all others. with mtq.set_quantizer_by_cfg_context( language_model, - [{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg], + [{"quantizer_name": "*", "enable": False}, *kv_entries], ): mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop) return language_model @@ -987,12 +974,20 @@ def quantize_main( ): # Load the recipe up front so we can detect layerwise calibration before batch-size probing. recipe = None - if args.recipe is not None and not args.auto_quantize_bits: + if args.recipe is not None: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) - if not isinstance(recipe, ModelOptPTQRecipe): + if not isinstance(recipe, (ModelOptPTQRecipe, ModelOptAutoQuantizeRecipe)): raise TypeError( - f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" + f"Expected PTQ or AutoQuantize recipe, but got {type(recipe).__name__} " + f"from {args.recipe}" + ) + # Fail-fast on conflicting budget sources: a recipe carries its own + # effective_bits, so silently honoring one over the other would be a + # reproducibility hazard. + if args.auto_quantize_bits is not None: + raise ValueError( + "Cannot combine --auto_quantize_bits with --recipe; the recipe owns the budget." ) def _is_layerwise(obj): @@ -1043,7 +1038,9 @@ def _is_layerwise(obj): else: sample_input_single_batch = None - run_auto_quant = args.auto_quantize_bits is not None + run_auto_quant = args.auto_quantize_bits is not None or isinstance( + recipe, ModelOptAutoQuantizeRecipe + ) args.batch_size = get_max_batch_size( language_model, @@ -1057,7 +1054,7 @@ def _is_layerwise(obj): print(f"Use calib batch_size {args.batch_size}") calib_dataloader, first_text_speech_dataset = make_calib_dataloader( - args, language_model, processor, tokenizer, device, model_type + args, language_model, processor, tokenizer, device, model_type, recipe=recipe ) # Detect if this is a Nemotron VL model using architecture-based detection @@ -1067,20 +1064,104 @@ def _is_layerwise(obj): args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model ) - if args.auto_quantize_bits: - assert len(args.qformat.split(",")) > 1, ( - "Auto quantization needs multiple quantization format." - ) + # All auto_quantize() knobs are resolved here before calling the helper. + # Helper is a leaf orchestrator — it does not know whether inputs came from + # CLI args or a recipe. + if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits is not None: + default_disabled_layers = [ + entry["quantizer_name"] + for entry in _default_disabled_quantizer_cfg + if "parent_class" not in entry + ] - auto_quantize( - args, - language_model, - calib_dataloader, - auto_quantize_method=args.auto_quantize_method, - auto_quantize_score_size=args.auto_quantize_score_size, - auto_quantize_checkpoint=args.auto_quantize_checkpoint, - full_model=full_model, - ) + # Resolve --kv_cache_qformat to a full QuantizeConfig dict (or None). Used as the + # CLI fallback when a recipe is silent on KV cache, and as the sole source for the + # CLI autoquant branch. Cast variants get use_constant_amax injected at this layer + # so the helper can stay format-agnostic (it just checks use_constant_amax to + # decide whether to calibrate). + def _cli_kv_cache_quant_cfg(): + if args.kv_cache_qformat == "none": + return None + cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])) + if args.kv_cache_qformat in _KV_CAST_FORMATS: + _set_kv_cache_constant_amax(cfg["quant_cfg"]) + return cfg + + if isinstance(recipe, ModelOptAutoQuantizeRecipe): + aq = recipe.auto_quantize + + # mtq.auto_quantize labels candidates by upstream identity: dicts that ARE + # an mtq.X_CFG object get the constant's name in logs (e.g. NVFP4_DEFAULT_CFG); + # all other dicts get "CUSTOM_N" plus a "results may not be optimal" warning. + # Recipe candidates come from .model_dump() — equal by value but not identity, + # so we'd lose the friendly names. Substitute the canonical object back when + # the dump matches a known preset, so logs and the warning line up with CLI. + # The match check uses exclude_unset=True so it compares against the + # preset YAML's natural shape (mtq.X_CFG dicts don't carry Pydantic-filled + # defaults). The payload still passes the full dump to upstream. + def _candidate_for_mtq(fmt): + strict = fmt.model_dump(exclude_unset=True) + for cfg in QUANT_CFG_CHOICES.values(): + if cfg == strict: + return cfg + return fmt.model_dump() + + auto_quantize( + args, + language_model, + calib_dataloader, + full_model=full_model, + auto_quantize_method=aq.method, + auto_quantize_score_size=aq.num_score_steps, + auto_quantize_checkpoint=args.auto_quantize_checkpoint, + constraints=aq.constraints.model_dump(exclude_none=True), + quantization_formats=[_candidate_for_mtq(fmt) for fmt in aq.candidate_formats], + disabled_layers=aq.disabled_layers or default_disabled_layers, + kv_cache_quant_cfg=( + aq.kv_cache.model_dump() + if aq.kv_cache is not None + else _cli_kv_cache_quant_cfg() + ), + ) + else: + qformat_list = args.qformat.split(",") + assert len(qformat_list) > 1, "Auto quantization needs multiple quantization format." + assert all( + qformat + in [ + "fp8", + "int8_sq", + "int8_wo", + "int4_awq", + "nvfp4", + "nvfp4_awq", + "nvfp4_mse", + "w4a8_awq", + "fp8_pb_wo", + "w4a8_mxfp4_fp8", + "nvfp4_mlp_only", + "nvfp4_experts_only", + "nvfp4_omlp_only", + "nvfp4_local_hessian", + "mxfp8", + ] + for qformat in qformat_list + ), ( + "One or more quantization formats provided are not supported for unified checkpoint export" + ) + auto_quantize( + args, + language_model, + calib_dataloader, + full_model=full_model, + auto_quantize_method=args.auto_quantize_method, + auto_quantize_score_size=args.auto_quantize_score_size, + auto_quantize_checkpoint=args.auto_quantize_checkpoint, + constraints={"effective_bits": args.auto_quantize_bits}, + quantization_formats=[QUANT_CFG_CHOICES[fmt] for fmt in qformat_list], + disabled_layers=default_disabled_layers, + kv_cache_quant_cfg=_cli_kv_cache_quant_cfg(), + ) else: # mono quantization @@ -1198,9 +1279,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--recipe", help=( - "PTQ recipe YAML file or name without suffix (e.g. general/ptq/fp8_default-kv_fp8_cast, " - "general/ptq/nvfp4_default-kv_fp8_cast, general/ptq/nvfp4_default-kv_nvfp4_cast). " - "When set, --kv_cache_qformat is ignored; the recipe fully determines KV cache config." + "PTQ or AutoQuantize recipe YAML file or name without suffix " + "(e.g. general/ptq/nvfp4_default-kv_fp8_cast, " + "general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast). " + "PTQ recipes fully own quant config; AutoQuantize recipes own search config " + "and may optionally override --kv_cache_qformat via their kv_cache field." ), default=None, ) diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 749d80a933..218bac82f4 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -19,8 +19,9 @@ import warnings from enum import Enum +from typing import Literal -from pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001 @@ -36,6 +37,7 @@ class RecipeType(str, Enum): """List of recipe types. See ``RECIPE_TYPE_TO_CLASS`` at the bottom for the schema mapping.""" PTQ = "ptq" + AUTO_QUANTIZE = "auto_quantize" SPECULATIVE_EAGLE = "speculative_eagle" SPECULATIVE_DFLASH = "speculative_dflash" SPECULATIVE_MEDUSA = "speculative_medusa" @@ -104,6 +106,93 @@ class ModelOptPTQRecipe(ModelOptRecipeBase): ) +class AutoQuantizeConstraints(ModeloptBaseConfig): + """Constraints passed to ``mtq.auto_quantize`` (matches its dict shape). + + Today only ``effective_bits`` is supported upstream. When new constraint + keys land (e.g., ``cost_model`` / ``cost`` from PR #1497), add them as + fields here so ``.model_dump(exclude_none=True)`` produces the dict + upstream expects. + """ + + effective_bits: float = ModeloptField( + default=4.8, + title="Effective bits per weight", + description="Average weight-storage bits target for the LP, in (0, 16].", + ) + + @field_validator("effective_bits") + @classmethod + def _validate_effective_bits(cls, v: float) -> float: + if not (0 < v <= 16): + raise ValueError(f"effective_bits must be in (0, 16], got {v}") + return v + + +class AutoQuantizeConfig(ModeloptBaseConfig): + """Schema for the ``auto_quantize`` block in an AutoQuantize recipe.""" + + constraints: AutoQuantizeConstraints = Field( + title="Search constraints + cost model", + description="LP budget and cost model.", + ) + + candidate_formats: list[QuantizeConfig] = ModeloptField( + default=[], + title="Candidate quantization formats", + description="Per-layer search space; each entry is a full QuantizeConfig. " + "At least 2 entries required.", + ) + + method: Literal["gradient", "kl_div"] = ModeloptField( + default="gradient", + title="Sensitivity scoring method", + description="'gradient' (Taylor + Fisher, needs labels) or 'kl_div' (no labels).", + ) + + num_score_steps: int = ModeloptField( + default=128, + title="Phase-3 scoring sample count", + description="Number of batches for sensitivity scoring.", + ) + + disabled_layers: list[str] = ModeloptField( + default=[], + title="Excluded layer patterns", + description="Glob patterns; matching layers are excluded from the search.", + ) + + kv_cache: QuantizeConfig | None = ModeloptField( + default=None, + title="KV cache QuantizeConfig (optional)", + description="Optional full QuantizeConfig applied as a uniform post-step after the " + "LP search. Typically uses ``$import: configs/ptq/units/kv_*`` for a built-in KV " + "preset, or inlines a custom config. If omitted, the runtime --kv_cache_qformat " + "CLI flag is used as a fallback.", + ) + + @field_validator("candidate_formats") + @classmethod + def _at_least_two_candidates(cls, v: list[QuantizeConfig]) -> list[QuantizeConfig]: + if len(v) < 2: + raise ValueError( + "auto_quantize requires at least 2 candidate_formats. " + "For uniform quantization, use a PTQ recipe instead." + ) + return v + + +class ModelOptAutoQuantizeRecipe(ModelOptRecipeBase): + """Our config class for AutoQuantize recipes.""" + + metadata: RecipeMetadataConfig = _metadata_field(RecipeType.AUTO_QUANTIZE) + + auto_quantize: AutoQuantizeConfig = Field( + title="AutoQuantize config", + description="AutoQuantize search configuration. Required.", + ) + + class ModelOptSpeculativeRecipeBase(ModelOptRecipeBase): """Base class for speculative-decoding recipes. @@ -199,6 +288,7 @@ class ModelOptMedusaRecipe(ModelOptSpeculativeRecipeBase): # uses this for typed-list ``$import`` resolution; add a new entry when introducing a recipe. RECIPE_TYPE_TO_CLASS: dict[RecipeType, type[ModelOptRecipeBase]] = { RecipeType.PTQ: ModelOptPTQRecipe, + RecipeType.AUTO_QUANTIZE: ModelOptAutoQuantizeRecipe, RecipeType.SPECULATIVE_EAGLE: ModelOptEagleRecipe, RecipeType.SPECULATIVE_DFLASH: ModelOptDFlashRecipe, RecipeType.SPECULATIVE_MEDUSA: ModelOptMedusaRecipe, diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 0a9218ff7d..1e78c9372d 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -42,6 +42,7 @@ # must contain 'quantize'" instead of pydantic's generic missing-field error. _REQUIRED_SECTION_PER_RECIPE_TYPE: dict[RecipeType, str] = { RecipeType.PTQ: "quantize", + RecipeType.AUTO_QUANTIZE: "auto_quantize", RecipeType.SPECULATIVE_EAGLE: "eagle", RecipeType.SPECULATIVE_DFLASH: "dflash", RecipeType.SPECULATIVE_MEDUSA: "medusa", @@ -171,8 +172,12 @@ def _load_recipe_from_file( raw = yaml.safe_load(recipe_file.read_text()) or {} if not isinstance(raw, dict) or required_section not in raw: + # Speculative recipes use the family suffix ("EAGLE" not "SPECULATIVE_EAGLE"); + # every other multi-word recipe type uses the full value ("AUTO_QUANTIZE", not "QUANTIZE"). kind = ( - rtype.value.split("_", 1)[-1].upper() if "_" in rtype.value else rtype.value.upper() + rtype.value.removeprefix("speculative_").upper() + if rtype.value.startswith("speculative_") + else rtype.value.upper() ) raise ValueError(f"{kind} recipe file {recipe_file} must contain {required_section!r}.") diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index e4e633e36a..ba83139bcb 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -49,9 +49,16 @@ def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: """Estimate the compression ratio of a quantization configuration. - Right now, we find the minimum compression ratio across all quantizer attribute configs. - This is not perfect but is a good proxy for the overall compression ratio. We will improve - this in future releases. + If ``quant_cfg.effective_bits`` is set, returns ``effective_bits / 16`` directly. This + is the override path for formats whose true effective bits don't match the per-quantizer + ``num_bits`` heuristic — e.g., NVFP4 has 4 value bits + a per-16-element FP8 scale + (8/16 = 0.5 bits/element), so true effective bits = 4.5, not the heuristic's 4.0. + + Otherwise, falls back to the heuristic: minimum compression ratio across all enabled + quantizer attribute configs (``num_bits / 16`` for ints, ``(E + M + 1) / 16`` for FP + tuples). This is a good proxy for the overall compression ratio of formats without + block-scale overhead, but under-counts block-quantized formats. We will improve this + in future releases. Args: quant_cfg: The quantization configuration to estimate compression for. @@ -59,6 +66,8 @@ def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: Returns: float: The estimated compression ratio (0.0 to 1.0). """ + if quant_cfg.effective_bits is not None: + return quant_cfg.effective_bits / 16.0 def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): if isinstance(quantizer_attr_cfg, list): diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index fd95171ce4..0d8cf47684 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1160,6 +1160,25 @@ class QuantizeConfig(ModeloptBaseConfig): validate_default=True, ) + effective_bits: float | None = ModeloptField( + default=None, + title="Effective bits per element (autoquant cost override)", + description=( + "Optional override for the autoquant LP cost model. If set, replaces the " + "heuristic estimate derived from ``num_bits``. Mainly useful for block-quantized " + "formats where the heuristic under-counts due to per-block scale overhead " + "(e.g., NVFP4 actual=4.5 vs heuristic=4.0). Must be in (0, 16] when set. " + "Read only by autoquant; other quantization paths ignore this field." + ), + ) + + @field_validator("effective_bits") + @classmethod + def _validate_effective_bits(cls, v: float | None) -> float | None: + if v is not None and not (0 < v <= 16): + raise ValueError(f"effective_bits must be in (0, 16], got {v}") + return v + @field_validator("quant_cfg", mode="before") @classmethod def normalize_quant_cfg( diff --git a/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml new file mode 100644 index 0000000000..c4b9a71c11 --- /dev/null +++ b/modelopt_recipes/general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast.yaml @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# AutoQuantize recipe: mixed NVFP4 + FP8 per-layer search at 4.8 effective bits, +# FP8 KV cache (cast mode). Gradient-based sensitivity scoring; weight cost model. + +imports: + nvfp4: configs/ptq/presets/model/nvfp4 + fp8: configs/ptq/presets/model/fp8 + kv_fp8_cast: configs/ptq/units/kv_fp8_cast + +metadata: + recipe_type: auto_quantize + description: Mixed NVFP4 + FP8 at 4.8 effective bits with FP8 KV cache (cast). + +auto_quantize: + constraints: + effective_bits: 4.8 + + candidate_formats: + # NVFP4 true effective bits = 4 value bits + 8-bit FP8 scale per 16-element block + # = 4 + 0.5 = 4.5 bits/element. Override the heuristic's 4.0 so the LP cost is accurate. + - $import: nvfp4 + effective_bits: 4.5 + # FP8 effective bits = 8 (heuristic is correct, per-tensor scale is negligible). + - $import: fp8 + + kv_cache: + quant_cfg: + - $import: kv_fp8_cast + + method: gradient + num_score_steps: 128 + + disabled_layers: + - "*lm_head*" diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 4c4e2d07de..1927ce482d 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -20,7 +20,9 @@ import pytest +import modelopt.torch.quantization as mtq from modelopt.recipe.config import ( + ModelOptAutoQuantizeRecipe, ModelOptDFlashRecipe, ModelOptEagleRecipe, ModelOptPTQRecipe, @@ -243,6 +245,115 @@ def test_load_recipe_dir_missing_quantize_raises(tmp_path): load_recipe(tmp_path) +# --------------------------------------------------------------------------- +# load_recipe — AutoQuantize recipes +# --------------------------------------------------------------------------- + + +_AQ_MINIMAL_BODY = ( + "metadata:\n" + " recipe_type: auto_quantize\n" + "auto_quantize:\n" + " constraints:\n" + " effective_bits: 4.8\n" + " candidate_formats:\n" + " - algorithm: max\n" + " quant_cfg: []\n" + " - algorithm: max\n" + " quant_cfg: []\n" +) + + +def test_load_recipe_autoquantize_builtin(): + """load_recipe loads the built-in AutoQuantize recipe.""" + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + assert recipe.recipe_type == RecipeType.AUTO_QUANTIZE + assert isinstance(recipe, ModelOptAutoQuantizeRecipe) + aq = recipe.auto_quantize + assert aq.constraints.effective_bits == 4.8 + assert len(aq.candidate_formats) == 2 + # kv_cache is a full QuantizeConfig now (not a hardcoded qformat string). + assert aq.kv_cache is not None + assert aq.kv_cache.algorithm == "max" + assert len(aq.kv_cache.quant_cfg) >= 1 + + +def test_load_recipe_autoquantize_defaults(): + """Optional AutoQuantize fields use Pydantic defaults when omitted.""" + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + aq = recipe.auto_quantize + assert aq.method == "gradient" + assert aq.num_score_steps == 128 + + +def test_load_recipe_autoquantize_candidates_match_presets(): + """Built-in AutoQuantize recipe's $imported candidates equal preset + inline override.""" + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + candidates = recipe.auto_quantize.candidate_formats + + # NVFP4 candidate = canonical preset + inline effective_bits override. + expected_nvfp4 = {**mtq.NVFP4_DEFAULT_CFG, "effective_bits": 4.5} + assert candidates[0].model_dump(exclude_unset=True) == expected_nvfp4 + + # FP8 candidate = canonical preset exactly (no override). + assert candidates[1].model_dump(exclude_unset=True) == mtq.FP8_DEFAULT_CFG + + +def test_load_recipe_autoquantize_missing_section_raises(tmp_path): + """An AutoQuantize recipe missing the ``auto_quantize`` section is rejected + with the clean loader-level error (not the generic pydantic missing-field one).""" + bad = tmp_path / "bad.yml" + bad.write_text("metadata:\n recipe_type: auto_quantize\n") + with pytest.raises( + ValueError, match=r"AUTO_QUANTIZE recipe file .* must contain 'auto_quantize'" + ): + load_recipe(bad) + + +def test_load_recipe_autoquantize_too_few_candidates_raises(tmp_path): + """candidate_formats with fewer than 2 entries is rejected.""" + bad = tmp_path / "bad.yml" + bad.write_text( + "metadata:\n" + " recipe_type: auto_quantize\n" + "auto_quantize:\n" + " constraints:\n" + " effective_bits: 4.8\n" + " candidate_formats:\n" + " - algorithm: max\n" + " quant_cfg: []\n" + ) + with pytest.raises(ValueError, match="at least 2"): + load_recipe(bad) + + +def test_load_recipe_autoquantize_effective_bits_out_of_range_raises(tmp_path): + """effective_bits outside (0, 16] is rejected.""" + bad = tmp_path / "bad.yml" + bad.write_text(_AQ_MINIMAL_BODY.replace("effective_bits: 4.8", "effective_bits: 20")) + with pytest.raises(ValueError, match="effective_bits"): + load_recipe(bad) + + +def test_load_recipe_autoquantize_kv_cache_optional(tmp_path): + """kv_cache is optional; recipes without it parse fine and aq.kv_cache is None.""" + recipe_file = tmp_path / "aq.yml" + recipe_file.write_text(_AQ_MINIMAL_BODY) + recipe = load_recipe(recipe_file) + assert recipe.auto_quantize.kv_cache is None + + +def test_load_recipe_autoquantize_effective_bits_inline_override(): + """Inline $import + sibling effective_bits merge applied per candidate.""" + recipe = load_recipe("general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast") + candidates = recipe.auto_quantize.candidate_formats + + # NVFP4 candidate carries the override. + assert candidates[0].effective_bits == 4.5 + # FP8 candidate has no override; heuristic still applies. + assert candidates[1].effective_bits is None + + # --------------------------------------------------------------------------- # load_recipe — EAGLE speculative decoding # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e..7ab308079c 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -375,6 +375,32 @@ def test_estimate_quant_compression(): assert estimate_quant_compression(fp8_affine_kv_cfg) == 0.5 +def test_estimate_quant_compression_effective_bits_override(): + """``QuantizeConfig.effective_bits`` overrides the per-quantizer num_bits heuristic. + + Validates two things: + 1. The override path returns ``effective_bits / 16`` and bypasses the heuristic. + 2. Without the override, the heuristic returns the unchanged baseline value. + """ + # NVFP4 — heuristic returns 4.0 bits / 16 = 0.25, but true effective bits is 4.5. + nvfp4_cfg = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG) + assert nvfp4_cfg.effective_bits is None + assert estimate_quant_compression(nvfp4_cfg) == 0.25 # heuristic baseline + + nvfp4_cfg_overridden = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=4.5) + assert estimate_quant_compression(nvfp4_cfg_overridden) == 4.5 / 16.0 + + # Override can also represent a higher cost (e.g., conservative for a sensitive recipe). + nvfp4_cfg_high = mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=16.0) + assert estimate_quant_compression(nvfp4_cfg_high) == 1.0 + + # Out-of-range values are rejected by the Pydantic validator. + with pytest.raises(ValueError, match="effective_bits must be in"): + mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=0.0) + with pytest.raises(ValueError, match="effective_bits must be in"): + mtq.config.QuantizeConfig(**mtq.NVFP4_DEFAULT_CFG, effective_bits=17.0) + + @pytest.mark.parametrize("method", ["gradient", "kl_div"]) def test_auto_quantize_checkpoint_resume(method, tmp_path, capsys): """Test that checkpoint can be used to resume an interrupted search."""