Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Changelog
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
- Add ``DATASET_COMBOS`` to ``modelopt.torch.utils.dataset_utils`` — single ``--dataset`` tokens that fan out to multiple registered datasets; per-entry ``num_samples`` is split evenly across the members. Initial combos: ``cnn_nemotron_v2_mix`` (``cnn_dailymail`` + ``nemotron-post-training-dataset-v2``, used by ``hf_ptq.py`` when no ``--dataset`` is provided) and ``nemotron-post-training-v3`` (the seven ``nvidia/Nemotron-*`` SFT datasets added in #1498, mirroring the `nemotron-post-training-v3 collection <https://huggingface.co/collections/nvidia/nemotron-post-training-v3>`_). Combo names are listed by ``get_supported_datasets()`` and surfaced in ``--dataset`` help. ``get_dataset_dataloader`` rejects inputs that mix a combo with one of its member datasets (e.g. ``cnn_dailymail,cnn_nemotron_v2_mix``) to avoid double-sampling, and ``get_dataset_samples`` rejects combo names so callers route through the dataloader. ``hf_ptq.py`` default ``--calib_size`` is bumped from ``512`` to ``1024`` so the total calibration sample count under the new default combo matches the previous two-dataset fallback.
- The ``nemotron-sft-agentic-v2`` registered dataset (added in #1498) now uses only the ``search`` split. The previously configured ``interactive_agent`` and ``tool_calling`` splits contain content-level defects (heterogeneous schema and a malformed JSON row, respectively) that cause pyarrow's streaming JSON reader to fail deterministically.
- Add shared Megatron-Core calibration forward loop: ``modelopt.torch.utils.plugins.megatron_calibration.get_megatron_calibration_forward_loop`` produces the ``forward_loop`` callable expected by ``mtq.quantize`` / ``mtp.prune``. Replaces the bespoke calibration loops in Megatron-LM and Megatron-Bridge for quantization and pruning with a single canonical implementation.
- Add ``pack=True`` mode to ``get_dataset_dataloader`` (Megatron-LM pretraining-style global-stream document packing): all raw samples concatenated EOS-separated into one token stream, sliced into uniform ``max_sample_length`` rows. Used by the shared megatron calibration loop.

**Bug Fixes**

Expand Down
2 changes: 1 addition & 1 deletion examples/megatron_bridge/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ torchrun --nproc_per_node 2 prune_minitron.py \
--hf_model_name_or_path Qwen/Qwen3-8B \
--prune_target_memory_mb 12288 \
--seq_length 4096 \
--calib_mbs 1 \
--calib_batch_size 1 \
--output_hf_path /tmp/Qwen3-8B-Pruned-12GB
```

Expand Down
34 changes: 14 additions & 20 deletions examples/megatron_bridge/prune_minitron.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@
import modelopt.torch.prune as mtp
import modelopt.torch.utils.distributed as dist
from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0
from modelopt.torch.utils.plugins.mbridge import (
get_hf_mbridge_calibration_loop,
load_mbridge_model_from_hf,
)
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
from modelopt.torch.utils.plugins.megatron_calibration import get_megatron_calibration_forward_loop
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu


Expand Down Expand Up @@ -104,11 +102,7 @@ def get_args() -> argparse.Namespace:
"--calib_num_samples", type=int, default=1024, help="Number of samples for calibration"
)
# TODO: Add support for pre-training dataset (pre-tokenized)
# TODO: only allow mbs>1 for pretraining dataset
parser.add_argument(
"--calib_mbs", type=int, default=1, choices=[1], help="Calibration micro-batch size"
)
parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size")
parser.add_argument("--calib_batch_size", type=int, default=1, help="Calibration batch size")
parser.add_argument("--seq_length", type=int, default=4096)
# Pruning parameters
parser.add_argument(
Expand Down Expand Up @@ -164,8 +158,8 @@ def get_args() -> argparse.Namespace:
default=None,
help=(
"Batch size used only for KV-cache sizing in --prune_target_memory_mb. "
"Defaults to --calib_mbs when not set. "
"Use this to target an inference batch size that differs from the calibration micro-batch size."
"Defaults to --calib_batch_size when not set. "
"Use this to target an inference batch size that differs from the calibration batch size."
),
)

Expand Down Expand Up @@ -296,16 +290,14 @@ def main(args: argparse.Namespace):
init_model_parallel=True,
moe_grouped_gemm=False,
)
forward_loop = get_hf_mbridge_calibration_loop(
model=model,
provider=provider,
tokenizer=tokenizer,
hf_model_name_or_path=args.hf_model_name_or_path,
trust_remote_code=args.trust_remote_code,
forward_loop = get_megatron_calibration_forward_loop(
tokenizer,
dataset_name=args.calib_dataset_name,
num_samples=args.calib_num_samples,
micro_batch_size=args.calib_mbs,
global_batch_size=args.calib_gbs,
seq_length=args.seq_length,
batch_size=args.calib_batch_size,
# pack=True uses Megatron pretraining-style global-stream document packing
pack=True,
)

pruning_config = {
Expand Down Expand Up @@ -385,7 +377,9 @@ def score_func(m):
pruning_config["top_k"] = args.top_k
# memory_mb constraint requires batch_size and seq_length
pruning_config["batch_size"] = (
args.inference_batch_size if args.inference_batch_size is not None else args.calib_mbs
args.inference_batch_size
if args.inference_batch_size is not None
else args.calib_batch_size
)
Comment thread
kevalmorabia97 marked this conversation as resolved.
pruning_config["seq_length"] = args.seq_length
print_rank_0(f"Pruning constraints: {pruning_constraints}")
Expand Down
14 changes: 6 additions & 8 deletions examples/pruning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ Please see example snippets of both modes for Minitron pruning on Megatron-Bridg
```python
import torch
import modelopt.torch.prune as mtp
from modelopt.torch.utils.plugins.mbridge import (
get_hf_mbridge_calibration_loop,
load_mbridge_model_from_hf,
from modelopt.torch.utils.plugins.mbridge import load_mbridge_model_from_hf
from modelopt.torch.utils.plugins.megatron_calibration import (
get_megatron_calibration_forward_loop,
)

# Import the Megatron-Bridge Qwen3-8B model from Hugging Face checkpoint
Expand All @@ -67,13 +67,11 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf
)

# Set up the forward loop to run on 1024 train samples
forward_loop = get_hf_mbridge_calibration_loop(
model=model,
provider=provider,
tokenizer=tokenizer,
hf_model_name_or_path="Qwen/Qwen3-8B",
forward_loop = get_megatron_calibration_forward_loop(
tokenizer,
dataset_name="nemotron-post-training-dataset-v2",
num_samples=1024,
seq_length=4096,
)

# Run pruning on the unwrapped model
Expand Down
103 changes: 98 additions & 5 deletions modelopt/torch/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import json
import os
import random
from collections.abc import Callable, Iterator
from contextlib import contextmanager, suppress
from pathlib import Path
Expand Down Expand Up @@ -557,15 +558,61 @@ def __len__(self):
return len(next(iter(self.encodings.values())))


def _pack_documents_into_rows(
    samples: list[str], tokenizer: "PreTrainedTokenizerBase", seq_length: int, num_rows: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Global-stream document packing (Megatron-LM pretraining style).

    Tokenize the raw samples one after another into a single token stream
    (each document followed by EOS when the tokenizer defines one), then slice
    that stream into uniform-length rows. Rows can (and usually do) start
    mid-document.

    Args:
        samples: Raw text documents, consumed in order until the token budget
            of ``num_rows * seq_length`` tokens is met.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``,
            ``eos_token_id`` and ``pad_token_id``.
        seq_length: Number of tokens per output row. Must be positive.
        num_rows: Maximum number of rows to produce.

    Returns:
        ``(input_ids, attention_mask)`` long tensors of shape
        ``(rows, seq_length)`` with ``rows <= num_rows``. Every row except
        possibly the last consists only of real tokens (mask=1). When the stream
        runs out before ``num_rows`` rows are filled, the final partial row has
        mask=1 over the real tail and mask=0 over trailing pad. If no tokens are
        available the tensors have shape ``(0, seq_length)``.

    Raises:
        ValueError: If ``seq_length`` is not positive.
    """
    if seq_length <= 0:
        raise ValueError(f"seq_length must be a positive integer, got {seq_length}.")

    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Many causal-LM tokenizers ship without a pad token. Padded positions are
        # always masked out (mask=0), so any valid id works as filler.
        pad_id = eos_id if eos_id is not None else 0

    token_budget = max(num_rows, 0) * seq_length
    token_stream: list[int] = []
    for text in samples:
        # Stop tokenizing as soon as the stream can fill every requested row.
        if len(token_stream) >= token_budget:
            break
        token_stream.extend(tokenizer.encode(text, add_special_tokens=False))
        if eos_id is not None:
            token_stream.append(eos_id)

    # Drop the overshoot from the last document so that at most ``num_rows`` rows
    # are cut and a partial row can only appear when the stream ran short.
    del token_stream[token_budget:]

    rows_ids: list[list[int]] = [
        token_stream[start : start + seq_length]
        for start in range(0, len(token_stream), seq_length)
    ]
    rows_masks: list[list[int]] = [[1] * len(row) for row in rows_ids]

    # Right-pad the trailing partial row (if any) up to ``seq_length``.
    if rows_ids and len(rows_ids[-1]) < seq_length:
        missing = seq_length - len(rows_ids[-1])
        rows_ids[-1].extend([pad_id] * missing)
        rows_masks[-1].extend([0] * missing)

    # Explicit reshape keeps the result 2-D even when no rows were produced
    # (``torch.tensor([])`` alone would be 1-D with shape ``(0,)``).
    shape = (len(rows_ids), seq_length)
    return (
        torch.tensor(rows_ids, dtype=torch.long).reshape(shape),
        torch.tensor(rows_masks, dtype=torch.long).reshape(shape),
    )


def get_dataset_dataloader(
dataset_name: str | list[str] = "cnn_dailymail",
tokenizer: "PreTrainedTokenizerBase | None" = None,
batch_size: int = 1,
num_samples: int | list[int] = 512,
max_sample_length: int = 512,
device: torch.device | None = None,
device: torch.device | str | None = None,
include_labels: bool = False,
apply_chat_template: bool = False,
pack: bool = False,
) -> DataLoader:
"""Get a dataloader with the dataset name and tokenizer of the target model.

Expand All @@ -576,12 +623,25 @@ def get_dataset_dataloader(
an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
tokenizer: Instance of HuggingFace tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
max_sample_length: Maximum length of a sample.
num_samples: Number of samples from the dataset (interpreted as number of *output
rows* in both ``pack=False`` and ``pack=True`` modes — in packed mode the
loader oversamples raw text 8x to ensure enough docs to fill all rows).
max_sample_length: Maximum length of a sample (or per-row length under ``pack=True``).
device: Target device for the returned dataloader.
include_labels: Whether to include labels in the dataloader.
include_labels: Whether to include labels in the dataloader (ignored when
``pack=True``).
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset).
pack: If True, use global-stream document packing (Megatron-LM pretraining
style): all raw samples are concatenated into one EOS-separated token
stream and sliced into uniform-length rows. Rows can (and usually do)
start mid-document — this matches the distribution Megatron's blended
``.bin``/``.idx`` pretraining uses, so the trained model has seen this
pattern extensively. Non-final rows are fully real tokens (no pad); only
the trailing partial row (when the stream runs out before reaching
``num_samples`` rows) is padded. Default ``False`` for backwards-compatibility
with the prior one-doc-per-row tokenize-and-pad behavior; calibration
callers should pass ``True``.

Returns:
An instance of dataloader.
Expand Down Expand Up @@ -633,13 +693,46 @@ def get_dataset_dataloader(
expanded_num_samples.append(n)
dataset_name, num_samples = expanded_names, expanded_num_samples

# Sample count semantics:
# - pack=False: gather exactly `num_sample` raw docs per source, one per output row.
# - pack=True: oversample 8x per source to ensure enough raw docs to fill all rows,
# since each row greedily packs multiple docs.
sample_multiplier = 8 if pack else 1
all_samples = []
for ds_name, num_sample in zip(dataset_name, num_samples):
samples = get_dataset_samples(
ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
ds_name,
num_sample * sample_multiplier,
apply_chat_template=apply_chat_template,
tokenizer=tokenizer,
)
all_samples.extend(samples)

# Multi-source pack=True without shuffling would consume all of oversampled source 1's docs
# before any of oversampled source 2 are reached
if pack and len(dataset_name) > 1:
random.Random(0).shuffle(all_samples)

if pack:
total_rows = sum(num_samples)
input_ids, attention_mask = _pack_documents_into_rows(
all_samples, tokenizer, max_sample_length, total_rows
)
if input_ids.shape[0] < total_rows:
warn(
f"pack=True produced {input_ids.shape[0]} rows out of {total_rows} "
f"requested — raw text exhausted before filling all rows (8x oversample "
f"of num_samples was insufficient). Increase `num_samples` or shorten "
f"`max_sample_length`."
)
if device:
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
tokenized_dataset = _CustomDataset(
{"input_ids": input_ids, "attention_mask": attention_mask}
)
return DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=False)

batch_encoded = tokenizer(
all_samples,
return_tensors="pt",
Expand Down
3 changes: 3 additions & 0 deletions modelopt/torch/utils/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

from modelopt.torch.utils import import_plugin

with import_plugin("megatron_calibration"):
from .megatron_calibration import *

with import_plugin("megatron_generate"):
from .megatron_generate import *

Expand Down
Loading
Loading