Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 160 additions & 77 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
from modelopt.recipe import (
ModelOptAutoQuantizeRecipe,
ModelOptPTQRecipe,
ModelOptRecipeBase,
load_recipe,
)
from modelopt.torch.export import (
export_hf_checkpoint,
export_hf_vllm_fq_checkpoint,
Expand Down Expand Up @@ -208,6 +213,7 @@ def make_calib_dataloader(
tokenizer: PreTrainedTokenizerBase | None,
device: torch.device,
model_type: str | None,
recipe: ModelOptRecipeBase | None = None,
) -> tuple[DataLoader | _DeviceDataLoader, str | None]:
calib_dataloader = None
first_text_speech_dataset = None
Expand Down Expand Up @@ -271,8 +277,12 @@ def make_calib_dataloader(
assert tokenizer is not None and isinstance(
tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
), "The PreTrainedTokenizer must be set"
# Labels are only needed for gradient-based auto_quantize
include_labels = (
# Labels are only needed for gradient-based auto_quantize (CLI or recipe path).
is_autoquant_recipe_gradient = (
isinstance(recipe, ModelOptAutoQuantizeRecipe)
and recipe.auto_quantize.method == "gradient"
)
include_labels = is_autoquant_recipe_gradient or (
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
)

Expand All @@ -292,48 +302,32 @@ def auto_quantize(
args: argparse.Namespace,
language_model: torch.nn.Module,
calib_dataloader: DataLoader,
auto_quantize_method="gradient",
auto_quantize_score_size=128,
auto_quantize_checkpoint=None,
full_model: torch.nn.Module | None = None,
*,
auto_quantize_method: str,
auto_quantize_score_size: int,
auto_quantize_checkpoint: str | None,
constraints: dict,
quantization_formats: list[dict],
disabled_layers: list[str],
kv_cache_quant_cfg: dict | None,
):
"""Auto search quantization of multiple formats."""
"""Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize,
run KV cache post-step. All knobs are explicit keyword-only args; the
caller (dispatch site in ``quantize_main``) is responsible for resolving
them from either CLI args or a recipe before invoking this function.
"""

if args.calib_with_images:
raise NotImplementedError(
"AutoQuantize with image-text calibration is not supported yet. "
"Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
)

assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), (
assert args.inference_pipeline_parallel <= 1, (
"Auto Quantization is not supported for pipeline parallel size > 1"
)

qformat_list = args.qformat.split(",")
assert qformat_list, "No quantization formats provided"
# Check if all provided quantization formats are supported
assert all(
qformat
in [
"fp8",
"int8_sq",
"int8_wo",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"nvfp4_mse",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_local_hessian",
"mxfp8",
]
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"

# When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
# use full_model's lm_head to compute logits/loss from hidden states.
is_base_model = (
Expand Down Expand Up @@ -384,49 +378,42 @@ def forward_step(model, batch):

language_model, _ = mtq.auto_quantize(
language_model,
constraints={"effective_bits": args.auto_quantize_bits},
constraints=constraints,
data_loader=calib_dataloader,
forward_step=forward_step,
loss_func=loss_func, # Only used for gradient-based method
# TRTLLM only support one quantization format or None (do not quantize, internally supported)
quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list],
quantization_formats=quantization_formats, # type: ignore[arg-type]
num_calib_steps=len(calib_dataloader),
# AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration.
num_score_steps=min(
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
),
verbose=True,
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
disabled_layers=[
entry["quantizer_name"]
for entry in _default_disabled_quantizer_cfg
if "parent_class" not in entry
],
disabled_layers=disabled_layers,
method=auto_quantize_method,
checkpoint=auto_quantize_checkpoint,
)

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
# We need to explicitly set up KV cache quantization after auto_quantize
enable_quant_kv_cache = args.kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
if enable_quant_kv_cache:
kv_cache_quant_cfg = copy.deepcopy(
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
)
kv_cache_quant_cfg = [
e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*"
print(f"{'Enable' if kv_cache_quant_cfg is not None else 'Disable'} KV cache quantization")
if kv_cache_quant_cfg is not None:
kv_entries = [
e for e in copy.deepcopy(kv_cache_quant_cfg["quant_cfg"]) if e["quantizer_name"] != "*"
] # keep other quantizers from auto_quantize

if args.kv_cache_qformat in _KV_CAST_FORMATS:
_set_kv_cache_constant_amax(kv_cache_quant_cfg)

mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
if args.kv_cache_qformat not in _KV_CAST_FORMATS:
mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_entries)
# Calibrate only when at least one KV entry doesn't pin amax via use_constant_amax.
# Cast-variant presets (kv_fp8_cast, kv_nvfp4_cast) bake this in; data-driven
# variants (kv_fp8, kv_nvfp4, etc.) need a calibration pass.
needs_calibration = not all(
(e.get("cfg") or {}).get("use_constant_amax") is True for e in kv_entries
)
if needs_calibration:
# Calibrate only the KV cache quantizers; disable all others.
with mtq.set_quantizer_by_cfg_context(
language_model,
[{"quantizer_name": "*", "enable": False}, *kv_cache_quant_cfg],
[{"quantizer_name": "*", "enable": False}, *kv_entries],
):
mtq.calibrate(language_model, algorithm="max", forward_loop=calibrate_loop)
return language_model
Expand Down Expand Up @@ -987,12 +974,20 @@ def quantize_main(
):
# Load the recipe up front so we can detect layerwise calibration before batch-size probing.
recipe = None
if args.recipe is not None and not args.auto_quantize_bits:
if args.recipe is not None:
print(f"Use recipe {args.recipe} for quantization")
recipe = load_recipe(args.recipe)
if not isinstance(recipe, ModelOptPTQRecipe):
if not isinstance(recipe, (ModelOptPTQRecipe, ModelOptAutoQuantizeRecipe)):
raise TypeError(
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
f"Expected PTQ or AutoQuantize recipe, but got {type(recipe).__name__} "
f"from {args.recipe}"
)
# Fail-fast on conflicting budget sources: a recipe carries its own
# effective_bits, so silently honoring one over the other would be a
# reproducibility hazard.
if args.auto_quantize_bits is not None:
raise ValueError(
"Cannot combine --auto_quantize_bits with --recipe; the recipe owns the budget."
)

def _is_layerwise(obj):
Expand Down Expand Up @@ -1043,7 +1038,9 @@ def _is_layerwise(obj):
else:
sample_input_single_batch = None

run_auto_quant = args.auto_quantize_bits is not None
run_auto_quant = args.auto_quantize_bits is not None or isinstance(
recipe, ModelOptAutoQuantizeRecipe
)

args.batch_size = get_max_batch_size(
language_model,
Expand All @@ -1057,7 +1054,7 @@ def _is_layerwise(obj):
print(f"Use calib batch_size {args.batch_size}")

calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
args, language_model, processor, tokenizer, device, model_type
args, language_model, processor, tokenizer, device, model_type, recipe=recipe
)

# Detect if this is a Nemotron VL model using architecture-based detection
Expand All @@ -1067,20 +1064,104 @@ def _is_layerwise(obj):
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)

if args.auto_quantize_bits:
assert len(args.qformat.split(",")) > 1, (
"Auto quantization needs multiple quantization format."
)
# All auto_quantize() knobs are resolved here before calling the helper.
# Helper is a leaf orchestrator — it does not know whether inputs came from
# CLI args or a recipe.
if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think auto_quantize may not have that many users, so we may want to just remove the args support for auto_quantize and move completely to YAML recipes. What do you think, @realAsma @meenchen?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree; I think we can now remove the auto_quantize-specific args that are part of the proposed recipe.

default_disabled_layers = [
entry["quantizer_name"]
for entry in _default_disabled_quantizer_cfg
if "parent_class" not in entry
]

auto_quantize(
args,
language_model,
calib_dataloader,
auto_quantize_method=args.auto_quantize_method,
auto_quantize_score_size=args.auto_quantize_score_size,
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
full_model=full_model,
)
# Resolve --kv_cache_qformat to a full QuantizeConfig dict (or None). Used as the
# CLI fallback when a recipe is silent on KV cache, and as the sole source for the
# CLI autoquant branch. Cast variants get use_constant_amax injected at this layer
# so the helper can stay format-agnostic (it just checks use_constant_amax to
# decide whether to calibrate).
def _cli_kv_cache_quant_cfg():
if args.kv_cache_qformat == "none":
return None
cfg = copy.deepcopy(getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat]))
if args.kv_cache_qformat in _KV_CAST_FORMATS:
_set_kv_cache_constant_amax(cfg["quant_cfg"])
return cfg

if isinstance(recipe, ModelOptAutoQuantizeRecipe):
aq = recipe.auto_quantize

# mtq.auto_quantize labels candidates by upstream identity: dicts that ARE
# an mtq.X_CFG object get the constant's name in logs (e.g. NVFP4_DEFAULT_CFG);
# all other dicts get "CUSTOM_N" plus a "results may not be optimal" warning.
# Recipe candidates come from .model_dump() — equal by value but not identity,
# so we'd lose the friendly names. Substitute the canonical object back when
# the dump matches a known preset, so logs and the warning line up with CLI.
# The match check uses exclude_unset=True so it compares against the
# preset YAML's natural shape (mtq.X_CFG dicts don't carry Pydantic-filled
# defaults). The payload still passes the full dump to upstream.
def _candidate_for_mtq(fmt):
strict = fmt.model_dump(exclude_unset=True)
for cfg in QUANT_CFG_CHOICES.values():
if cfg == strict:
return cfg
return fmt.model_dump()

auto_quantize(
args,
language_model,
calib_dataloader,
full_model=full_model,
auto_quantize_method=aq.method,
auto_quantize_score_size=aq.num_score_steps,
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
constraints=aq.constraints.model_dump(exclude_none=True),
quantization_formats=[_candidate_for_mtq(fmt) for fmt in aq.candidate_formats],
disabled_layers=aq.disabled_layers or default_disabled_layers,
kv_cache_quant_cfg=(
aq.kv_cache.model_dump()
if aq.kv_cache is not None
else _cli_kv_cache_quant_cfg()
),
)
else:
qformat_list = args.qformat.split(",")
assert len(qformat_list) > 1, "Auto quantization needs multiple quantization format."
assert all(
qformat
in [
"fp8",
"int8_sq",
"int8_wo",
"int4_awq",
"nvfp4",
"nvfp4_awq",
"nvfp4_mse",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"nvfp4_experts_only",
"nvfp4_omlp_only",
"nvfp4_local_hessian",
"mxfp8",
]
for qformat in qformat_list
), (
"One or more quantization formats provided are not supported for unified checkpoint export"
)
auto_quantize(
args,
language_model,
calib_dataloader,
full_model=full_model,
auto_quantize_method=args.auto_quantize_method,
auto_quantize_score_size=args.auto_quantize_score_size,
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
constraints={"effective_bits": args.auto_quantize_bits},
quantization_formats=[QUANT_CFG_CHOICES[fmt] for fmt in qformat_list],
disabled_layers=default_disabled_layers,
kv_cache_quant_cfg=_cli_kv_cache_quant_cfg(),
)

else:
# mono quantization
Expand Down Expand Up @@ -1198,9 +1279,11 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--recipe",
help=(
"PTQ recipe YAML file or name without suffix (e.g. general/ptq/fp8_default-kv_fp8_cast, "
"general/ptq/nvfp4_default-kv_fp8_cast, general/ptq/nvfp4_default-kv_nvfp4_cast). "
"When set, --kv_cache_qformat is ignored; the recipe fully determines KV cache config."
"PTQ or AutoQuantize recipe YAML file or name without suffix "
"(e.g. general/ptq/nvfp4_default-kv_fp8_cast, "
"general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast). "
"PTQ recipes fully own quant config; AutoQuantize recipes own search config "
"and may optionally override --kv_cache_qformat via their kv_cache field."
),
default=None,
)
Expand Down
Loading
Loading