Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions modelopt_recipes/models/Kimi-K2.5/dflash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model DFlash offline training recipe for Kimi-K2.5.

metadata:
  recipe_type: speculative_dflash
  description: DFlash offline training recipe for Kimi-K2.5.

# maps to ModelArguments (main.py)
model:
  model_name_or_path: moonshotai/Kimi-K2.5
  trust_remote_code: true
  use_fake_base_for_offline: true

# maps to DataArguments (main.py)
data:
  data_path:
  offline_data_path: <path to offline data>
  # Jinja chat template with {% generation %} tags for answer_only_loss.
  # Required when answer_only_loss=true. Set in per-model launcher YAML.
  # Templates are in modelopt_recipes/general/speculative_decoding/chat_templates/
  chat_template:

# maps to TrainingArguments (main.py)
training:
  # --- commonly modified ---
  output_dir: ckpts/kimi-k25-dflash
  num_train_epochs: 10
  per_device_train_batch_size: 1
  learning_rate: 6.0e-4
  warmup_steps: 100
  training_seq_len: 4096
  logging_steps: 100
  save_steps: 5000
  cp_size: 1
  dp_shard_size: 1
  disable_tqdm: true
  estimate_ar: false
  ar_validate_steps: 0
  answer_only_loss: true

  # --- rarely modified ---
  do_eval: false
  lr_scheduler_type: linear
  save_strategy: steps
  weight_decay: 0.0
  dataloader_drop_last: true
  bf16: true
  tf32: true
  remove_unused_columns: false
  ddp_find_unused_parameters: true
  ddp_timeout: 1800
  report_to: tensorboard

# maps to DFlashConfig (modelopt/torch/speculative/config.py).
# dflash_mask_token_id falls back to tokenizer.mask_token_id when unset; set
# explicitly here if the tokenizer does not provide one.
dflash:
  dflash_block_size: 8
  dflash_num_anchors: 512
  dflash_use_torch_compile: false
  dflash_self_logit_distillation: true
  dflash_loss_decay_factor: 4.0
  dflash_architecture_config:
    num_hidden_layers: 5
    # sliding_window and layer_types are inherited from base model config automatically
73 changes: 73 additions & 0 deletions modelopt_recipes/models/Kimi-K2.5/eagle3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model EAGLE3 offline training recipe for Kimi-K2.5.
# Mirrors examples/speculative_decoding/scripts/train_kimi_k25_offline.sh.

metadata:
  recipe_type: speculative_eagle
  description: EAGLE3 offline training recipe for Kimi-K2.5.

# maps to ModelArguments (main.py)
model:
  model_name_or_path: moonshotai/Kimi-K2.5
  trust_remote_code: true
  use_fake_base_for_offline: true

# maps to DataArguments (main.py)
data:
  data_path: input_conversations/train.jsonl
  offline_data_path: <path to offline data>
  draft_vocab_cache:
  vlm_img_dir:
  vlm_processor:

# maps to TrainingArguments (main.py)
training:
  # --- commonly modified ---
  output_dir: ckpts/kimi-k25-eagle3
  num_train_epochs: 1
  per_device_train_batch_size: 1
  learning_rate: 1.0e-4
  warmup_steps: 1000
  training_seq_len: 4096
  logging_steps: 100
  save_steps: 8192
  cp_size: 1
  disable_tqdm: false
  estimate_ar: false
  ar_validate_steps: -1
  answer_only_loss: false

  # --- rarely modified ---
  do_eval: false
  lr_scheduler_type: linear
  save_strategy: steps
  weight_decay: 0.0
  dataloader_drop_last: true
  bf16: true
  tf32: true
  remove_unused_columns: false

# maps to EagleConfig (modelopt/torch/speculative/config.py).
eagle:
  # eagle_offline is derived from data.offline_data_path; do not set here.
  eagle_decoder_type: kimik2
  eagle_ttt_steps: 3
  eagle_mix_hidden_states: false
  eagle_use_torch_compile: true
  eagle_self_logit_distillation: true
  eagle_freeze_base_model: true
  eagle_loss_decay_factor: 0.9
  eagle_hidden_state_distillation: false
  eagle_reuse_base_decoder: false
  eagle_report_acc: true
  eagle_enable_nvtx: false
  # Rope scaling: disable during training (default_config.py uses rope_type=default),
  # inject YaRN during export for long-context inference.
  eagle_export_rope_scaling:
    rope_type: yarn
    factor: 32.0
    original_max_position_embeddings: 2048
  # Overrides for the defaults in modelopt/torch/speculative/eagle/default_config.py.
  eagle_architecture_config: {}
32 changes: 32 additions & 0 deletions modelopt_recipes/models/Qwen3-0.6B/dflash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model DFlash training recipe for Qwen3-0.6B.

metadata:
  recipe_type: speculative_dflash
  description: DFlash training recipe for Qwen3-0.6B.

# maps to ModelArguments (main.py)
model:
  model_name_or_path: Qwen/Qwen3-0.6B
  trust_remote_code: false
  use_fake_base_for_offline: false

# maps to DataArguments (main.py)
data:
  data_path:
  offline_data_path:

# maps to TrainingArguments (main.py)
training:
  output_dir:
  training_seq_len: 512
  answer_only_loss: true

# maps to DFlashConfig (modelopt/torch/speculative/config.py).
dflash:
  dflash_block_size: 8
  dflash_mask_token_id: 151669  # Qwen3 tokenizer mask token
  dflash_architecture_config:
    num_hidden_layers: 2  # small draft for 0.6B base
30 changes: 30 additions & 0 deletions modelopt_recipes/models/Qwen3-8B/dflash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model DFlash training recipe for Qwen3-8B.

metadata:
  recipe_type: speculative_dflash
  description: DFlash training recipe for Qwen3-8B.

# maps to ModelArguments (main.py)
model:
  model_name_or_path: Qwen/Qwen3-8B
  trust_remote_code: false
  use_fake_base_for_offline: false

# maps to DataArguments (main.py)
data:
  data_path:
  offline_data_path:

# maps to TrainingArguments (main.py)
training:
  output_dir:
  training_seq_len: 4096

# maps to DFlashConfig (modelopt/torch/speculative/config.py).
dflash:
  dflash_block_size: 16
  dflash_loss_decay_factor: 7.0  # paper Eq.4: gamma=7 pairs with block_size=16
  dflash_mask_token_id: 151669  # Qwen3 tokenizer mask token
30 changes: 30 additions & 0 deletions modelopt_recipes/models/Qwen3-8B/eagle3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model EAGLE3 training recipe for Qwen3-8B.
# Used by both online and offline EAGLE3 pipelines; the launcher YAML supplies
# data.data_path (online) or data.offline_data_path (offline).

metadata:
  recipe_type: speculative_eagle
  description: EAGLE3 training recipe for Qwen3-8B.

# maps to ModelArguments (main.py)
model:
  model_name_or_path: Qwen/Qwen3-8B
  trust_remote_code: false
  use_fake_base_for_offline: false

# maps to DataArguments (main.py)
data:
  data_path:
  offline_data_path:

# maps to TrainingArguments (main.py)
training:
  output_dir:
  training_seq_len: 4096

# maps to EagleConfig (modelopt/torch/speculative/config.py).
# Qwen3 uses the llama-family decoder, which is the EagleConfig default.
eagle: {}
38 changes: 38 additions & 0 deletions modelopt_recipes/models/Qwen3.5-4B/dflash.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Per-model DFlash training recipe for Qwen3.5-4B.
#
# NOTE: Qwen3.5-4B has non-standard head_dim=160. The draft model overrides the
# attention architecture (32 heads, head_dim=128) for vLLM KV cache compatibility.

metadata:
  recipe_type: speculative_dflash
  description: DFlash training recipe for Qwen3.5-4B (head_dim workaround for vLLM).

# maps to ModelArguments (main.py)
model:
  model_name_or_path: Qwen/Qwen3.5-4B
  trust_remote_code: false
  use_fake_base_for_offline: false

# maps to DataArguments (main.py)
data:
  data_path:
  offline_data_path:

# maps to TrainingArguments (main.py)
training:
  output_dir:
  training_seq_len: 4096

# maps to DFlashConfig (modelopt/torch/speculative/config.py).
dflash:
  dflash_mask_token_id: 248070  # Qwen3.5 tokenizer mask token (different from Qwen3)
  dflash_architecture_config:
    # Override base head_dim=160 to head_dim=128 for vLLM KV cache compatibility.
    num_attention_heads: 32
    num_key_value_heads: 8
    head_dim: 128
    intermediate_size: 9728
    rope_theta: 10000000
21 changes: 5 additions & 16 deletions tests/regression/torch/speculative/test_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,23 @@
import pytest
from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command

DFLASH_YAML = str(
MODELOPT_ROOT / "modelopt_recipes" / "general" / "speculative_decoding" / "dflash.yaml"
)
DFLASH_YAML = str(MODELOPT_ROOT / "modelopt_recipes" / "models" / "Qwen3-0.6B" / "dflash.yaml")

CHAT_TEMPLATE = str(
MODELOPT_ROOT
/ "tools"
/ "launcher"
/ "examples"
/ "Qwen"
/ "Qwen3-0.6B"
/ "chat_template_train.jinja"
MODELOPT_ROOT / "modelopt_recipes" / "models" / "Qwen3-0.6B" / "chat_template_train.jinja"
)

SYNTH_DATA = str(MODELOPT_ROOT / "examples" / "dataset" / "synthetic_conversations_1k.jsonl")

# Match tools/launcher/examples/Qwen/Qwen3-0.6B/hf_online_dflash.yaml
# Match tools/launcher/examples/Qwen/Qwen3-0.6B/hf_online_dflash.yaml. Model-specific
# settings (block_size, mask_token_id, training_seq_len, answer_only_loss, draft
# num_hidden_layers) live in the per-model recipe at DFLASH_YAML.
_DFLASH_OVERRIDES = [
f"data.data_path={SYNTH_DATA}",
f"data.chat_template={CHAT_TEMPLATE}",
"training.training_seq_len=512",
"training.per_device_train_batch_size=2",
"training.logging_steps=100",
"training.answer_only_loss=true",
"dflash.dflash_block_size=8",
"dflash.dflash_mask_token_id=151669",
"dflash.dflash_use_torch_compile=False",
"dflash.dflash_architecture_config.num_hidden_layers=2",
]


Expand Down
20 changes: 5 additions & 15 deletions tests/regression/torch/speculative/test_dflash_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,23 @@
import pytest
from _test_utils.examples.run_command import MODELOPT_ROOT, run_example_command

DFLASH_YAML = str(
MODELOPT_ROOT / "modelopt_recipes" / "general" / "speculative_decoding" / "dflash.yaml"
)
DFLASH_YAML = str(MODELOPT_ROOT / "modelopt_recipes" / "models" / "Qwen3-0.6B" / "dflash.yaml")

CHAT_TEMPLATE = str(
MODELOPT_ROOT
/ "tools"
/ "launcher"
/ "examples"
/ "Qwen"
/ "Qwen3-0.6B"
/ "chat_template_train.jinja"
MODELOPT_ROOT / "modelopt_recipes" / "models" / "Qwen3-0.6B" / "chat_template_train.jinja"
)

SYNTH_DATA = str(MODELOPT_ROOT / "examples" / "dataset" / "synthetic_conversations_1k.jsonl")

# Match _DFLASH_OVERRIDES in test_dflash.py so the offline run is comparable to online.
# Model-specific settings live in DFLASH_YAML; only env-/run-specific knobs go here.
# logging_steps is overridden lower than the online test so the shorter offline run
# still produces multiple log entries.
_DFLASH_OVERRIDES = [
f"data.chat_template={CHAT_TEMPLATE}",
"training.training_seq_len=512",
"training.per_device_train_batch_size=2",
"training.logging_steps=50",
"training.answer_only_loss=true",
"dflash.dflash_block_size=8",
"dflash.dflash_mask_token_id=151669",
"dflash.dflash_use_torch_compile=False",
"dflash.dflash_architecture_config.num_hidden_layers=2",
]

# Number of conversations to dump. Smaller than the full 1K to keep dump time
Expand Down
9 changes: 2 additions & 7 deletions tools/launcher/examples/Qwen/Qwen3-0.6B/hf_online_dflash.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,17 @@ pipeline:
task_0:
script: common/specdec/dflash_online_training.sh
args:
- --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml
- --config modules/Model-Optimizer/modelopt_recipes/models/Qwen3-0.6B/dflash.yaml
- model.model_name_or_path=<<global_vars.hf_model>>
- data.data_path=modules/Model-Optimizer/examples/dataset/synthetic_conversations_1k.jsonl
- data.chat_template=examples/Qwen/Qwen3-0.6B/chat_template_train.jinja
- data.chat_template=modules/Model-Optimizer/modelopt_recipes/models/Qwen3-0.6B/chat_template_train.jinja
- training.output_dir=/scratchspace/dflash_qwen3_0.6b
- training.num_train_epochs=3
- training.training_seq_len=512
- training.per_device_train_batch_size=2
- training.save_steps=500
- training.logging_steps=100
- training.disable_tqdm=true
- training.answer_only_loss=true
- dflash.dflash_block_size=8
- dflash.dflash_mask_token_id=151669
- dflash.dflash_use_torch_compile=False
- dflash.dflash_architecture_config.num_hidden_layers=2
environment:
- MAX_FINAL_LOSS: "2.0"
- MIN_FINAL_ACC: "0.40"
Expand Down
3 changes: 1 addition & 2 deletions tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ pipeline:
task_2:
script: common/eagle3/train_eagle.sh
args:
- --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml
- --config modules/Model-Optimizer/modelopt_recipes/models/Qwen3-8B/eagle3.yaml
- model.model_name_or_path=<<global_vars.hf_model>>
- data.offline_data_path=/scratchspace/offline_hidden_states
- training.output_dir=/scratchspace/eagle3
- training.training_seq_len=4096
- training.disable_tqdm=true
- training.ar_validate_steps=500000
slurm_config:
Expand Down
Loading
Loading