diff --git a/tests/engine/test_dense_train_engine.py b/tests/engine/test_dense_train_engine.py index f252827504..88bff9e0f7 100644 --- a/tests/engine/test_dense_train_engine.py +++ b/tests/engine/test_dense_train_engine.py @@ -83,13 +83,13 @@ def warmup_fn(x): seq_ctx = seq_ctx.split(sequence_parallel_mesh=sp_mesh) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) seq_ctx = seq_ctx_list[0] loss_ctx = loss_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/engine/test_moe_train_engine.py b/tests/engine/test_moe_train_engine.py index 1dc8685096..4302aa160d 100644 --- a/tests/engine/test_moe_train_engine.py +++ b/tests/engine/test_moe_train_engine.py @@ -93,12 +93,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -184,12 +184,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, 
sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/engine/test_moe_train_engine_float8.py b/tests/engine/test_moe_train_engine_float8.py index 15ea4b6730..7246b25143 100644 --- a/tests/engine/test_moe_train_engine_float8.py +++ b/tests/engine/test_moe_train_engine_float8.py @@ -87,12 +87,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -165,12 +165,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] loss_log = 
engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -264,12 +264,12 @@ def warmup_fn(x): seq_ctx.to('cuda') seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] logs_info = engine.train_step(engine_input)["logs_info"] grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/loss/test_ce_loss.py b/tests/loss/test_ce_loss.py index 863eae4da2..195bc08f42 100644 --- a/tests/loss/test_ce_loss.py +++ b/tests/loss/test_ce_loss.py @@ -72,7 +72,7 @@ def test_global_loss_reduction(self, loss_mode, grad_accumulation_steps, chunk_s for data in data_batch: seq_ctx = data["seq_ctx"] seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None) loss_ctx_list.append(loss_ctx) loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list]) @@ -172,7 +172,7 @@ def test_other_loss_reduction(self, loss_reduction, loss_mode, grad_accumulation for data in data_batch: seq_ctx = data["seq_ctx"] seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None) loss_ctx_list.append(loss_ctx) loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list]) @@ -310,7 +310,7 @@ def 
test_sp_global_loss_reduction(self, loss_mode, sp_size, grad_accumulation_st sp_mesh = data_mesh['sp'] seq_ctx.sequence_parallel_mesh = sp_mesh seq_ctx_list = [seq_ctx] - loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] if sp_size > 1: seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh) @@ -397,7 +397,7 @@ def test_sp_others_loss_reduction(self, loss_reduction, loss_mode, sp_size, grad sp_mesh = data_mesh['sp'] seq_ctx.sequence_parallel_mesh = sp_mesh seq_ctx_list = [seq_ctx] - loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] if sp_size > 1: seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh) diff --git a/tests/loss/test_grpo_loss.py b/tests/loss/test_grpo_loss.py index 0ee2378941..7a1747a125 100644 --- a/tests/loss/test_grpo_loss.py +++ b/tests/loss/test_grpo_loss.py @@ -147,7 +147,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size, if sp_size > 1: seq_ctx = seq_ctx.split(sp_mesh) seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels_list_rank[iter_idx], advantages=advantages_list_rank[iter_idx], sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]}, sp_mesh=sp_mesh) loss_ctx_list.append(loss_ctx) with torch.no_grad(): diff --git a/tests/loss/test_oreal_loss.py b/tests/loss/test_oreal_loss.py index 1e50b2e488..2ceae417d0 100644 --- a/tests/loss/test_oreal_loss.py +++ b/tests/loss/test_oreal_loss.py @@ -216,8 +216,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size, seq_ctx = seq_ctx.split(sp_mesh) seq_ctx_list.append(seq_ctx) loss_ctx = loss_cfg.build( - 
shifted_labels=shifted_labels_list_rank[iter_idx], - advantages=advantages_list_rank[iter_idx], + data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]}, sp_mesh=sp_mesh, ) loss_ctx_list.append(loss_ctx) diff --git a/tests/model/test_gpt_oss_moe.py b/tests/model/test_gpt_oss_moe.py index d0f011bc46..edabb9ea02 100644 --- a/tests/model/test_gpt_oss_moe.py +++ b/tests/model/test_gpt_oss_moe.py @@ -78,7 +78,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -87,7 +87,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class with torch.no_grad(): output = gpt_oss_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -141,7 +141,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size): loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -152,7 +152,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size): with torch.no_grad(): output = gpt_oss_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2)) diff --git 
a/tests/model/test_intern_s1.py b/tests/model/test_intern_s1.py index 00f96c8b22..db76848f25 100644 --- a/tests/model/test_intern_s1.py +++ b/tests/model/test_intern_s1.py @@ -78,7 +78,7 @@ def test_interns1_text_run(self, device, tol): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -87,7 +87,7 @@ def test_interns1_text_run(self, device, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -186,7 +186,7 @@ def test_interns1_image_run(self, device, sp_size, tol): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -195,7 +195,7 @@ def test_interns1_image_run(self, device, sp_size, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -256,7 +256,7 @@ def test_fsdp_text_accuracy(self, device, tol): seq_ctx_list = [seq_ctx] loss_cfg = CELossConfig() LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ 
-265,7 +265,7 @@ def test_fsdp_text_accuracy(self, device, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -370,7 +370,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol): seq_ctx_list = [seq_ctx] loss_cfg = CELossConfig() LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -379,7 +379,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol): with torch.no_grad(): output = interns1_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) diff --git a/tests/model/test_moe.py b/tests/model/test_moe.py index e34b58bd6f..c7c1c31374 100644 --- a/tests/model/test_moe.py +++ b/tests/model/test_moe.py @@ -62,12 +62,12 @@ def test_moe_config(self, dtype, device): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - model(seq_ctx=seq_ctx, loss_ctx=loss_ctx) + model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx}) class TestDistributedMoE(DeterministicDDPTestCase): @@ -135,15 +135,15 @@ def test_parallel_accuracy(self, dtype, device, dispatcher, n_shared_experts, fi seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, 
sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"] + loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"] - loss_expected = model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"] + loss_expected = model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"] torch.allclose(loss_expected, loss_parallel, atol=1e-6, rtol=1e-4) diff --git a/tests/model/test_qwen3_5.py b/tests/model/test_qwen3_5.py index 0431d36403..761d301b85 100644 --- a/tests/model/test_qwen3_5.py +++ b/tests/model/test_qwen3_5.py @@ -9,11 +9,13 @@ import torch.distributed as dist from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config from xtuner.v1.loss.ce_loss import CELossConfig -from xtuner.v1.model.moe.moe import SequenceContext +from xtuner.v1.model.moe.moe import SequenceContext, MTPConfig from xtuner.v1.utils.test_utils import init_data_mesh from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig from xtuner.v1.config import FSDPConfig from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh +from xtuner.v1.data_proto.utils import pad_to_multiple_of + import tempfile from pathlib import Path @@ -82,13 +84,28 @@ def _forward(self, model, type, device, sp_size): position_ids = tokenized_data['position_ids'].cuda() else: tokenizer = AutoTokenizer.from_pretrained(QWEN3_VL_MOE_PATH) - input_ids = tokenizer(f"今天天气不错,是学习的好日子。请听题: 1+1 等于多少?", - return_tensors="pt").input_ids.to(device) - labels = input_ids.clone() + tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_VL_MOE_PATH, rand_video_max_frames=14, + add_vision_id=True).build(tokenizer) + raw_data = { + "id": 3, "messages": [ + { + "role": "user", "content": [ + { + "type": "text", + "text": "Translate this into chinese: Where my eyes gaze, 
only memories remain; where my heart strays, only yesterday's pain; where my sight stays, only regret's refrain." + } + ] + }, + {"role": "assistant", "content": "目之所及,唯余旧忆;心之所向,唯余昨痛;眸之所留,唯余悔声。"} + ] + } + tokenized_data = tokenize_fn(raw_data) + input_ids = torch.tensor(tokenized_data['input_ids'])[None].cuda() + labels = torch.tensor(tokenized_data['labels'])[None].cuda() pixel_values = None image_grid_thw = None position_ids = None - + from transformers import Qwen3_5MoeForConditionalGeneration is_hf_model = isinstance(model, Qwen3_5MoeForConditionalGeneration) @@ -115,10 +132,8 @@ def _forward(self, model, type, device, sp_size): dist.all_reduce(output.loss.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) return output.loss else: - loss_cfg = CELossConfig() - - shift_input_ids = input_ids[:, :-1] - shifted_labels = labels[:, 1:] + shift_input_ids = pad_to_multiple_of(input_ids[:, :-1], padding_value=0, multiple_of=sp_size) + shifted_labels = pad_to_multiple_of(labels[:, 1:], padding_value=-100, multiple_of=sp_size) if position_ids is not None: position_ids = position_ids[..., :-1] @@ -127,22 +142,18 @@ def _forward(self, model, type, device, sp_size): data_mesh = init_data_mesh(device, sp_size=sp_size) sp_mesh = data_mesh["sp"] - seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) + seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to("cuda"),)) seq_ctx.image_grid_thw = image_grid_thw seq_ctx.pixel_values = pixel_values if position_ids is not None: seq_ctx.position_ids = position_ids - seq_ctx.to('cuda') + seq_ctx.to("cuda") if sp_size > 1: seq_ctx = seq_ctx.split(sp_mesh) - seq_ctx_list = [seq_ctx] - LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) - loss_ctx_list = [loss_ctx] - loss_ctx_list = LossContext.build_batches(loss_ctx_list) - loss_ctx = loss_ctx_list[0] - seq_ctx = seq_ctx_list[0] + data_batch = [{"seq_ctx": seq_ctx, "shifted_labels": 
shifted_labels}] + loss_ctx_batch = model.build_loss_ctx_batch(data_batch, sp_mesh=sp_mesh) + loss_ctx = loss_ctx_batch[0] with torch.no_grad(): output = model( @@ -221,6 +232,63 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol): self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol)) self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol)) + @parametrize.parametrize( + "device,sp_size,tol", + [ + ("cuda", 1, 1e-2), + ("cuda", 4, 1e-2), + ], + ) + def test_qwen3_5_vl_run_mtp(self, device, sp_size, tol): + self.create_pg(device) + loss_reference = { + "text": 1.5416, + "image": 3.6920, + "video": 8.2165, + } + + QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"] + + torch.cuda.empty_cache() + + with torch.device("meta"): + model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False) + model_cfg.text_config.mtp_config = MTPConfig(num_layers=1, loss_scaling_factor=1) + qwen3vl_model = model_cfg.build().to(torch.bfloat16) + + qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH) + qwen3vl_model.eval() + + losses = {} + + loss_xtuner_text = self._forward(qwen3vl_model, type="text", device=device, sp_size=sp_size) + self.assertFalse(torch.isnan(loss_xtuner_text), "MTP text loss should not be NaN") + + loss_xtuner_image = self._forward(qwen3vl_model, type="image", device=device, sp_size=sp_size) + self.assertFalse(torch.isnan(loss_xtuner_image), "MTP image loss should not be NaN") + + loss_xtuner_video = self._forward(qwen3vl_model, type="video", device=device, sp_size=sp_size) + self.assertFalse(torch.isnan(loss_xtuner_video), "MTP video loss should not be NaN") + + losses["text"] = loss_xtuner_text + losses["image"] = loss_xtuner_image + losses["video"] = loss_xtuner_video + + for key, loss in losses.items(): + self.assertTrue( + torch.allclose( + loss, torch.tensor( + loss_reference[key], + device=loss_xtuner_text.device, + dtype=loss_xtuner_text.dtype + ), + atol=tol, + rtol=tol + ), + f"Expected {key}
loss around {loss_reference[key]}, but got {loss.item()}" + ) + + @parametrize.parametrize( "device,sp_size", [ @@ -233,6 +301,7 @@ def test_save_hf_with_mtp(self, device, sp_size): with torch.device("meta"): model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False) + model_cfg.text_config.mtp_config = MTPConfig(num_layers=1) qwen3vl_model = model_cfg.build().to(torch.bfloat16) fsdp_config = FSDPConfig(cpu_offload=False) @@ -262,8 +331,6 @@ def test_save_hf_with_mtp(self, device, sp_size): # Verify all original HF weights are preserved correctly for key in origin_index["weight_map"].keys(): - if "mtp" in key: - continue # TODO: remove this after MTP is implemented origin_safetensor_name = origin_index["weight_map"][key] saved_safetensor_name = saved_index["weight_map"][key] diff --git a/tests/model/test_qwen3_dense.py b/tests/model/test_qwen3_dense.py index efe327cb90..94aded6805 100644 --- a/tests/model/test_qwen3_dense.py +++ b/tests/model/test_qwen3_dense.py @@ -64,7 +64,7 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class): loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -75,7 +75,7 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class): with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol)) @@ -120,7 +120,7 @@ def test_fsdp_accuracy(self, device, tp_size): seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels,
sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -131,7 +131,7 @@ def test_fsdp_accuracy(self, device, tp_size): with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2)) @@ -196,7 +196,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -207,7 +207,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) assert "loss" in output diff --git a/tests/model/test_qwen3_moe.py b/tests/model/test_qwen3_moe.py index 4113be5a83..dbf5e7366d 100644 --- a/tests/model/test_qwen3_moe.py +++ b/tests/model/test_qwen3_moe.py @@ -98,7 +98,7 @@ def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_mod loss_cfg = CELossConfig(mode=loss_mode) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -107,7 +107,7 @@ def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, 
loss_mod with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] losses.append(loss) @@ -181,7 +181,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size, model_type): loss_cfg = CELossConfig() seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -190,7 +190,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size, model_type): with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) loss = output["loss"] losses.append(loss) @@ -257,7 +257,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi seq_ctx = SequenceContext.from_input_ids(input_ids=(shift_input_ids.to('cuda'),)) seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] @@ -268,7 +268,7 @@ def test_sliding_windows(self, use_sliding_window, max_window_layers, sliding_wi with torch.no_grad(): output = qwen_model( seq_ctx=seq_ctx, - loss_ctx=loss_ctx, + loss_ctx={"lm": loss_ctx}, ) assert "loss" in output diff --git a/tests/model/test_qwen3_tile_embedding.py b/tests/model/test_qwen3_tile_embedding.py index 79869da87a..b931395c5a 100644 --- a/tests/model/test_qwen3_tile_embedding.py +++ b/tests/model/test_qwen3_tile_embedding.py @@ -78,12 +78,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = 
loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] engine.train_step(engine_input) grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) @@ -153,12 +153,12 @@ def warmup_fn(x): seq_ctx.num_padding = pack_len seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None) + loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] seq_ctx = seq_ctx_list[0] - engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)] + engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})] engine.train_step(engine_input) grad_norm = engine.clip_grad_norm() engine.step_optimizer(grad_norm) diff --git a/tests/model/test_qwen3_vl.py b/tests/model/test_qwen3_vl.py index f6027e8799..4c6dc7ab47 100644 --- a/tests/model/test_qwen3_vl.py +++ b/tests/model/test_qwen3_vl.py @@ -139,7 +139,7 @@ def _test_all(self, hf_model, qwen3vl_model, type, device, sp_size, tol): seq_ctx_list = [seq_ctx] LossContext = loss_cfg.loss_ctx_cls - loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh) + loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh) loss_ctx_list = [loss_ctx] loss_ctx_list = LossContext.build_batches(loss_ctx_list) loss_ctx = loss_ctx_list[0] diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py index 1dcc81a5b0..4295dd54ed 100644 --- a/tests/train/test_trainer.py +++ b/tests/train/test_trainer.py @@ -112,6 +112,7 @@ def create_pg(self, device): 
@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) + @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @prepare def test_save_hf_interval(self): """Test save_hf is called at correct intervals during training.""" @@ -184,6 +185,7 @@ def test_save_hf_interval(self): @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) + @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @prepare def test_save_checkpoint_interval(self): self.create_pg(DEVICE) @@ -258,6 +260,7 @@ def test_save_checkpoint_interval(self): @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True)) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) + @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @prepare def test_resume(self): self.create_pg(DEVICE) @@ -738,6 +741,7 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch): assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1 +@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[])) @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine())) def test_resume_and_load_checkpoint_cfg(tmp_path: Path): # 0. 
prepare environment diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py index 69da8d045a..5a6722b8d0 100644 --- a/xtuner/v1/data_proto/sequence_context.py +++ b/xtuner/v1/data_proto/sequence_context.py @@ -5,7 +5,7 @@ from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Self -from .utils import pad_to_multiple_of, split_for_sequence_parallel +from .utils import gather_for_sequence_parallel, pad_to_multiple_of, split_for_sequence_parallel # Avoid using dataclass decorator here to get rid of extra ops called in pytorch 2.8 and above @@ -50,6 +50,12 @@ class SequenceContext: # moe routed_experts rollout_routed_experts: torch.Tensor | None + # Private backing attributes for SP shard reconstruction + _raw_input_ids: torch.LongTensor | None + _raw_inputs_embeds: torch.FloatTensor | None + _shard_start: int + _shard_size: int + def __init__( self, input_ids: torch.LongTensor | None, # shape (1, seq_len) @@ -71,6 +77,11 @@ def __init__( inputs_embeds: torch.FloatTensor | None = None, num_img_tokens: list[list[int]] | None = None, rollout_routed_experts: torch.Tensor | None = None, + # SP shard metadata: private, accessed via properties below + raw_input_ids: torch.LongTensor | None = None, + raw_inputs_embeds: torch.FloatTensor | None = None, + shard_start: int = 0, + shard_size: int = 0, ): # Only to distinguish parameters accepted by the constructor from attributes. 
For example, for `max_length_q`, # the argument can be an int, but as an attribute it can only be a tensor @@ -99,6 +110,10 @@ def __init__( self.inputs_embeds = inputs_embeds self.num_img_tokens = num_img_tokens self.rollout_routed_experts = rollout_routed_experts + self._raw_input_ids = raw_input_ids + self._raw_inputs_embeds = raw_inputs_embeds + self._shard_start = shard_start + self._shard_size = shard_size self.seq_idx = None seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] @@ -169,6 +184,7 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self: start = sp_input_ids.shape[1] * sequence_parallel_mesh.get_local_rank() end = start + sp_input_ids.shape[1] sp_num_padding = max(0, min(sp_input_ids.shape[1], end - num_non_padding)) + shard_size = sp_input_ids.shape[1] if self.position_ids is not None: pad_position_ids = pad_to_multiple_of(self.position_ids, 0, multiple_of, -1) @@ -205,6 +221,9 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self: inputs_embeds=self.inputs_embeds, num_img_tokens=self.num_img_tokens, rollout_routed_experts=self.rollout_routed_experts, + raw_input_ids=cast(torch.LongTensor, pad_input_ids), + shard_start=start, + shard_size=shard_size, ) return sp_seq_ctx else: @@ -308,6 +327,71 @@ def seq_lens_q(self) -> torch.LongTensor: def seq_lens_k(self) -> torch.LongTensor: return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] # type: ignore + @property + def raw_input_ids(self) -> torch.LongTensor | None: + """Full (un-split) input_ids across all SP ranks. + + In non-SP mode, returns ``input_ids`` directly. In SP mode, returns the + pre-stored full tensor if available; otherwise triggers an allgather and + caches the result for subsequent calls. + + Returns: + torch.LongTensor | None: The full input_ids tensor, or ``None`` if + ``input_ids`` is ``None``. 
+ """ + if self._raw_input_ids is not None: + return self._raw_input_ids + if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1: + return self.input_ids + assert self.input_ids is not None + gathered = gather_for_sequence_parallel( + self.input_ids, dim=1, sp_group=self.sequence_parallel_mesh.get_group() + ) + self._raw_input_ids = cast(torch.LongTensor, gathered) + return self._raw_input_ids + + @property + def raw_inputs_embeds(self) -> torch.FloatTensor | None: + """Full (un-split) inputs_embeds across all SP ranks. + + In non-SP mode, returns ``inputs_embeds`` directly. In SP mode, triggers + a single allgather on first access and caches the result for subsequent + calls, so the communication cost is paid at most once. + + Returns: + torch.FloatTensor | None: The full inputs_embeds tensor, or ``None`` if + ``inputs_embeds`` is ``None``. + """ + if self._raw_inputs_embeds is not None: + return self._raw_inputs_embeds + if self.inputs_embeds is None: + return None + if self.sequence_parallel_mesh is None or self.sequence_parallel_mesh.size() == 1: + return self.inputs_embeds + gathered = gather_for_sequence_parallel( + self.inputs_embeds, dim=1, sp_group=self.sequence_parallel_mesh.get_group() + ) + self._raw_inputs_embeds = cast(torch.FloatTensor, gathered) + return self._raw_inputs_embeds + + @property + def raw_position_ids(self) -> torch.LongTensor | None: + """Full (un-split) position_ids across all SP ranks. + + Returns: + torch.LongTensor | None: The full position_ids tensor. + """ + raise NotImplementedError("raw_position_ids is not yet implemented") + + @property + def raw_rollout_routed_experts(self) -> torch.Tensor | None: + """Full (un-split) rollout_routed_experts across all SP ranks. + + Returns: + torch.Tensor | None: The full rollout_routed_experts tensor. 
+ """ + raise NotImplementedError("raw_rollout_routed_experts is not yet implemented") + # TODO: 暂时没有用到,可能要删掉 def chunk(self, num_chunks: int) -> list[Self]: n = self.seq_lens_q.numel() @@ -374,6 +458,10 @@ def copy(self, **overrides) -> Self: inputs_embeds=overrides.get("inputs_embeds", self.inputs_embeds), num_img_tokens=overrides.get("num_img_tokens", self.num_img_tokens), rollout_routed_experts=overrides.get("rollout_routed_experts", self.rollout_routed_experts), + raw_input_ids=overrides.get("raw_input_ids", self._raw_input_ids), + raw_inputs_embeds=overrides.get("raw_inputs_embeds", self._raw_inputs_embeds), + shard_start=overrides.get("shard_start", self._shard_start), + shard_size=overrides.get("shard_size", self._shard_size), ) def to(self, device: torch.device | str): diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index 20047e8a80..a3b68902c0 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -440,9 +440,12 @@ def build( ) elif self.group_by_length: assert shuffle, "Currently only shuffling is supported for LengthGroupedSampler." 
- assert isinstance(dataset, (ExpandSoftPackDataset, _LegacySoftPackDataset, HardPackDataset)), ( - "Internal Error, LengthGroupedSampler requires ExpandSoftPackDataset or _LegacySoftPackDataset, " - f"but got {type(dataset)}" + assert isinstance( + dataset, + (ExpandSoftPackDataset, _LegacySoftPackDataset, HardPackDataset, MLLMPretrainHybridPackDataset), + ), ( + "Internal Error, LengthGroupedSampler requires ExpandSoftPackDataset, _LegacySoftPackDataset, " + f"HardPackDataset, or MLLMPretrainHybridPackDataset, but got {type(dataset)}" ) sampler = LengthGroupedSampler( dataset=dataset, dp_mesh=dp_mesh, global_batch_size=global_batch_size, seed=seed diff --git a/xtuner/v1/datasets/packing.py b/xtuner/v1/datasets/packing.py index 9d3d8945e8..c672eee03b 100644 --- a/xtuner/v1/datasets/packing.py +++ b/xtuner/v1/datasets/packing.py @@ -4,11 +4,12 @@ import os import random import tempfile +from collections.abc import Sequence from concurrent.futures import ProcessPoolExecutor from functools import cached_property, partial from multiprocessing import shared_memory from pathlib import Path -from typing import Sized +from typing import Sized, cast import numpy as np import torch @@ -16,6 +17,7 @@ from datasets import Dataset, concatenate_datasets from torch import distributed as dist from torch.utils.data import ConcatDataset +from torch.utils.data import Dataset as TorchDataset from tqdm import tqdm from xtuner.v1.utils import get_logger, is_local_rank0 @@ -309,7 +311,7 @@ def get_pack_infos_by_expand_soft_split( class ExpandSoftPackDataset(_LegacySoftPackDataset): def __init__( self, - datasets: list[JsonlDataset], + datasets: Sequence[JsonlDataset], pack_max_length: int = 2048, global_pack: bool = False, pack_extra_buffer_size: int = 1000, @@ -642,7 +644,7 @@ def get_state_dict(self): def load_state_dict(self, state_dict): ... 
-class MLLMPretrainHybridPackDataset(_LegacySoftPackDataset): +class MLLMPretrainHybridPackDataset(TorchDataset): def __init__( self, datasets: list[JsonlDataset], @@ -653,17 +655,12 @@ def __init__( pack_extra_buffer_size: int = 1000, # for ExpandSoftPackDataset pack_chunk_size: int = 10000, # for ExpandSoftPackDataset ): - self.pack_extra_buffer_size = pack_extra_buffer_size - self.pack_workers = pack_workers - self.torch_random_generator = torch.Generator() - self.pack_chunk_size = pack_chunk_size - if seed is not None: - self.torch_random_generator.manual_seed(seed) - logger.info(f"Using {self.pack_workers} pack workers for packing datasets.") - self.seed = seed - self.global_pack = global_pack self.pack_max_length = pack_max_length + self.global_pack = global_pack + self.pack_workers = pack_workers + self.pack_extra_buffer_size = pack_extra_buffer_size + self.pack_chunk_size = pack_chunk_size hard_pack_groups = [] soft_pack_groups = [] @@ -673,100 +670,81 @@ def __init__( elif isinstance(dset, JsonlDataset): hard_pack_groups.append(dset) - if global_pack: - hard_pack_datasets: list[Sized] = [] - if len(hard_pack_groups) > 0: - num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in hard_pack_groups]))] - hard_pack_datasets = [ConcatDataset(hard_pack_groups)] - - pack_infos_list = [] - for i, dataset in enumerate(hard_pack_datasets): - _infos = self.get_hard_pack_infos(dataset, i, num_tokens[i]) - pack_infos_list.extend(_infos) - hard_pack_len = len(pack_infos_list) - - soft_pack_datasets: list[Sized] = [] - if len(soft_pack_groups) > 0: - num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in soft_pack_groups]))] - proxy_attn_flops = [ - ndarray_to_mmap(np.concatenate([dset.proxy_attn_flops for dset in soft_pack_groups])) - ] - - soft_pack_datasets = [ConcatDataset(soft_pack_groups)] - for i, dataset in enumerate(soft_pack_datasets): - _infos = self.get_soft_pack_infos(dataset, i, num_tokens[i], proxy_attn_flops[i]) - 
pack_infos_list.extend(_infos) - pack_infos = Dataset.from_list(pack_infos_list) + dataset_list: list[HardPackDataset | ExpandSoftPackDataset] = [] - else: - raise NotImplementedError + if hard_pack_groups: + hard_pack_dataset = HardPackDataset( + datasets=hard_pack_groups, + pack_max_length=pack_max_length, + global_pack=global_pack, + seed=seed, + pack_workers=pack_workers, + ) + dataset_list.append(hard_pack_dataset) - self.hard_pack_datasets = hard_pack_datasets - self.datasets = soft_pack_datasets - self.hard_pack_len = hard_pack_len - self.pack_infos = pack_infos + if soft_pack_groups: + soft_pack_dataset = ExpandSoftPackDataset( + datasets=soft_pack_groups, + pack_max_length=pack_max_length, + global_pack=global_pack, + pack_extra_buffer_size=pack_extra_buffer_size, + pack_chunk_size=pack_chunk_size, + pack_workers=pack_workers, + seed=seed, + ) + dataset_list.append(soft_pack_dataset) - def get_hard_pack_item(self, item: int): - info = self.pack_infos[item] - dataset_id = info["dataset_id"] - ds = self.hard_pack_datasets[dataset_id] + assert dataset_list, "No datasets provided for packing." 
+ self.datasets: ConcatDataset[HardPackDataset | ExpandSoftPackDataset] = ConcatDataset(dataset_list) - indices = info["indices"] - s_off = info["start_offset"] - e_off = info["end_offset"] + @cached_property + def longest(self): + longest_list = [] + for dataset in self.datasets.datasets: + longest_list.extend(cast(HardPackDataset | ExpandSoftPackDataset, dataset).longest) + return longest_list - packed_list: list[dict] = [] + def __getitem__(self, item: int): + return self.datasets[item] - for i in range(len(indices)): - idx = indices[i] - sample = ds[idx] - ids = sample["input_ids"] - labs = sample.get("labels", None) + def __len__(self) -> int: + return len(self.datasets) - st = 0 if i != 0 else s_off - ed = len(ids) if i != len(indices) - 1 else e_off + def get_state_dict(self): + return { + "pack_max_length": self.pack_max_length, + "seed": self.seed, + "global_pack": self.global_pack, + "pack_extra_buffer_size": self.pack_extra_buffer_size, + "pack_chunk_size": self.pack_chunk_size, + } - packed_list.append( - { - "input_ids": ids[st:ed], - "labels": labs[st:ed] if labs is not None else None, - "num_tokens": ed - st, - } + def load_state_dict(self, state_dict): + if self.seed != state_dict["seed"]: + raise ValueError( + f"Cannot load state dict with different seed . Origin: {state_dict['seed']}, New: {self.seed}" ) - assert (total_num_tokens := sum(i["num_tokens"] for i in packed_list)) == self.pack_max_length, ( - f"Internal Error! Found size: {total_num_tokens} mismatch after hard packing." 
- ) - return packed_list - - def __getitem__(self, item: int): - if item < self.hard_pack_len: - return self.get_hard_pack_item(item) - else: - return super().__getitem__(item) - def get_hard_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray): - # shuffled indices - inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist() + if self.pack_max_length != state_dict["pack_max_length"]: + raise ValueError( + "Cannot load state dict with different pack_max_length " + f". Origin: {state_dict['pack_max_length']}, New: {self.pack_max_length}" + ) - pack_infos_list = get_pack_infos_by_hard_split( - inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers - ) - return pack_infos_list + if self.global_pack != state_dict["global_pack"]: + raise ValueError( + "Cannot load state dict with different global_pack " + f". Origin: {state_dict['global_pack']}, New: {self.global_pack}" + ) - def get_soft_pack_infos( - self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray, proxy_attn_flops: np.ndarray - ): - # shuffled indices - inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist() + if self.pack_extra_buffer_size != state_dict["pack_extra_buffer_size"]: + raise ValueError( + "Cannot load state dict with different pack_extra_buffer_size " + f". Origin: {state_dict['pack_extra_buffer_size']}, New: {self.pack_extra_buffer_size}" + ) - pack_infos_list = get_pack_infos_by_expand_soft_split( - inds, - dataset_id, - num_tokens, - proxy_attn_flops, - pack_max_length=self.pack_max_length, - pack_workers=self.pack_workers, - pack_chunk_size=self.pack_chunk_size, - pack_extra_buffer_size=self.pack_extra_buffer_size, - ) - return pack_infos_list + if self.pack_chunk_size != state_dict["pack_chunk_size"]: + raise ValueError( + "Cannot load state dict with different pack_chunk_size " + f". 
Origin: {state_dict['pack_chunk_size']}, New: {self.pack_chunk_size}" + ) diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index daa1189c8d..946585b82d 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -12,7 +12,7 @@ from xtuner.v1.utils import get_logger from .jsonl import JsonlDataset -from .packing import _LegacySoftPackDataset +from .packing import MLLMPretrainHybridPackDataset, _LegacySoftPackDataset logger = get_logger() @@ -49,7 +49,7 @@ class ParallelSampler(Sampler): def __init__( self, - dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset, + dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset | MLLMPretrainHybridPackDataset, global_batch_size: int, dp_mesh: DeviceMesh | None = None, shuffle: bool = True, @@ -173,7 +173,7 @@ class LengthGroupedSampler(Sampler): def __init__( self, - dataset: _LegacySoftPackDataset, + dataset: _LegacySoftPackDataset | MLLMPretrainHybridPackDataset, global_batch_size: int, dp_mesh: DeviceMesh | None = None, seed: Optional[int] = None, diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 2a576c59b7..f0fcfd1d88 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -172,7 +172,7 @@ def data_replicate_size(self) -> int: @torch.no_grad() def forward_only(self, seq_ctx: SequenceContext, loss_ctx: LogProbContext): - output = self.model(seq_ctx=seq_ctx, loss_ctx=loss_ctx) + output = self.model(seq_ctx=seq_ctx, loss_ctx=loss_ctx) # type: ignore[call-overload] return output def grad_accumulation_steps(self, data_batches_len: int): @@ -217,7 +217,7 @@ def train_step(self, data_batches: list[ModelItem]) -> TrainStepInfo: # Here we assume that the model can handle a list of seq_ctx and loss_ctx. 
output = self.model( seq_ctx=seq_ctx_list, - loss_ctx=loss_ctx_list, + loss_ctx=loss_ctx_list, # type: ignore[arg-type] ) output.free_nongrad_feature() @@ -421,6 +421,9 @@ def _get_total_loss(self, model_outputs: ModelOutputs) -> torch.Tensor: loss = torch.tensor(0.0, device=DEVICE) for key in model_outputs.model_fields: value = getattr(model_outputs, key) - if "loss" in key and isinstance(value, torch.Tensor): + if key == "mtp_loss": + for mtp_loss_name, mtp_loss in value.items(): + loss += mtp_loss + elif "loss" in key and isinstance(value, torch.Tensor): loss += value return loss diff --git a/xtuner/v1/loss/__init__.py b/xtuner/v1/loss/__init__.py index dfe3afc73d..4135e1f96b 100644 --- a/xtuner/v1/loss/__init__.py +++ b/xtuner/v1/loss/__init__.py @@ -1,19 +1,37 @@ from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs -from .ce_loss import CELossConfig, CELossContext +from .ce_loss import CELossConfig, CELossContext, LMHeadLossContext from .chunk_loss import ChunkLoss -from .moe_loss import BalancingLoss, ZLoss +from .moe_loss import ( + BalancingLoss, + BalancingLossConfig, + BalancingLossContext, + BalancingLossKwargs, + ZLoss, + ZLossConfig, + ZLossContext, + ZLossKwargs, +) +from .mtp_loss import MTPLossContext from .rl_loss import LogProbConfig, LogProbContext __all__ = [ "BalancingLoss", + "BalancingLossConfig", + "BalancingLossContext", + "BalancingLossKwargs", "ZLoss", + "ZLossConfig", + "ZLossContext", + "ZLossKwargs", "CELossContext", "CELossConfig", "ChunkLoss", "BaseLossConfig", "BaseLossContext", "BaseLossKwargs", + "LMHeadLossContext", + "MTPLossContext", "LogProbConfig", "LogProbContext", ] diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py index 8ae297b6f4..531a6dfdab 100644 --- a/xtuner/v1/loss/base_loss_ctx.py +++ b/xtuner/v1/loss/base_loss_ctx.py @@ -3,17 +3,10 @@ from typing import Annotated, Any, Literal, TypeVar import torch -import torch.distributed as dist import torch.nn as nn from 
cyclopts import Parameter from pydantic import BaseModel, ConfigDict from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.nn.functional import all_reduce -from typing_extensions import Self - -from xtuner.v1.loss.utils import sp_split - -from .chunk_loss import ChunkLoss # Do loss calibration among dp, sp and grad accumulation: @@ -46,18 +39,13 @@ class BaseLossKwargs(BaseModel): - """Everything needed to compute the loss.""" - - model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) - shifted_labels: torch.Tensor + """Everything needed to compute the loss. - def sp_split(self, sp_mesh: DeviceMesh) -> Self: - self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100) - return self + Subclasses should implement sp_split() and to() methods if they contain tensors that need to be split across + sequence parallel mesh or moved to device. + """ - def to(self, device: torch.device | str) -> Self: - self.shifted_labels = self.shifted_labels.to(device) - return self + model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) def chunk(self, chunk_size) -> list["BaseLossKwargs"]: tensor_fields: dict[str, tuple[torch.Tensor, ...]] = {} @@ -114,15 +102,35 @@ class BaseLossConfig(BaseModel): chunk_size: Annotated[int | None, Parameter(help="chunk size when mode is chunk")] = 1024 @property + @abstractmethod def loss_ctx_cls(self) -> type["BaseLossContext"]: raise NotImplementedError + # TODO: private property maybe not a good idea @property + @abstractmethod def _loss_kwargs_cls(self) -> type["BaseLossKwargs"]: raise NotImplementedError - def build(self, *args, **kwargs) -> "BaseLossContext": - raise NotImplementedError + @abstractmethod + def build( + self, + data: dict, + sp_mesh: "DeviceMesh | None" = None, + ) -> "BaseLossContext | None": + """Build loss context from data dict. 
+ + Subclasses should extract required fields from data dict and construct loss_kwargs. + + Args: + data (dict): Data dict containing all possible loss-related fields. + Different loss configs extract different fields as needed. + sp_mesh (DeviceMesh | None): Sequence parallel mesh. + + Returns: + BaseLossContext: Built loss context. + """ + ... # NOTE: Self type for BaseLossContext subclasses (F-bounded polymorphism) @@ -143,72 +151,10 @@ def __init__(self, loss_cfg: BaseLossConfig, loss_kwargs: BaseLossKwargs): self._batch_size = 1 @staticmethod - @abstractmethod - def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: ... - - @abstractmethod - def loss_fn( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None, - loss_kwargs: BaseLossKwargs, - ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: - """Step 2.a and 2.b in the loss calculation.""" - ... - - def eager_mode( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None, - loss_kwargs: BaseLossKwargs, - ): - return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs) - - def chunk_mode( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None, - loss_kwargs: BaseLossKwargs, - ): - assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode" - - chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size) - loss, extra_info = ChunkLoss.apply( - hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size - ) - return loss, (None, extra_info) - - def forward( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: - from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo - - assert self.loss_kwargs is not None, "loss_kwargs must be 
set before calling forward" - if head_bias is not None: - raise NotImplementedError("Loss does not support head_bias yet.") - - if self.loss_cfg.mode == "eager": - loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) - else: - loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) - - # TODO: yanhuida, should be removed - if not isinstance(extra_info, ModelForwardExtraLogInfo): - extra_info = ModelForwardExtraLogInfo(extra_info) - - extra_info["local_base_loss"] = loss.detach().clone() - - # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support - if dist.is_initialized(): - loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) - - return loss, (logits, extra_info) + def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: + for ctx in loss_ctx_list: + ctx._batch_size = len(loss_ctx_list) + return loss_ctx_list @classmethod def cat(cls: type[_BaseLossContextT], chunks: list[_BaseLossContextT]) -> _BaseLossContextT: diff --git a/xtuner/v1/loss/ce_loss.py b/xtuner/v1/loss/ce_loss.py index 7d70a43e0c..01874fd617 100644 --- a/xtuner/v1/loss/ce_loss.py +++ b/xtuner/v1/loss/ce_loss.py @@ -1,20 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Annotated, Any, Literal, Sequence, cast +from typing import Annotated, Any, Literal, Sequence, cast, Optional import torch import torch.distributed as dist import torch.nn.functional as F from cyclopts import Parameter from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.nn.functional import all_reduce from xtuner.v1.loss import BaseLossConfig, BaseLossContext, BaseLossKwargs -from xtuner.v1.utils.device import get_device +from xtuner.v1.loss.chunk_loss import ChunkLoss +from xtuner.v1.utils import ( + get_device, + get_logger, +) # from xtuner.v1.profiler.prober import ProberList from .utils import sp_gather, sp_split DEVICE = get_device() +logger = get_logger() class CELossConfig(BaseLossConfig): @@ -37,15 +43,35 @@ class CELossConfig(BaseLossConfig): def loss_ctx_cls(self) -> type["CELossContext"]: return CELossContext - def model_post_init(self, __context: Any) -> None: + @property + def _loss_kwargs_cls(self) -> type["CELossKwargs"]: + return CELossKwargs + + def model_post_init(self, _context: Any) -> None: if self.mode == "liger": assert self.loss_reduction == "token", "Currently, cannot use liger kernel with sample or square reduction" def build( self, - shifted_labels: torch.Tensor, + data: dict, sp_mesh: DeviceMesh | None = None, - ) -> "CELossContext": + ) -> "CELossContext | None": + """Build CELossContext from data dict. + + Args: + data (dict): Data dict containing loss-related fields. + Required: shifted_labels + sp_mesh (DeviceMesh | None): Sequence parallel mesh. + + Returns: + CELossContext | None: Built loss context. Returns None if shifted_labels + is not present in data dict. 
+ """ + if "shifted_labels" not in data: + return None + # Extract required fields from data + shifted_labels = data["shifted_labels"] + loss_kwargs = CELossKwargs(shifted_labels=shifted_labels).to(DEVICE) if sp_mesh is not None and sp_mesh.size() > 1: loss_kwargs = loss_kwargs.sp_split(sp_mesh) @@ -64,8 +90,16 @@ class CELossKwargs(BaseLossKwargs): shifted_labels: torch.Tensor loss_weight: torch.Tensor | None = None + def sp_split(self, sp_mesh: DeviceMesh) -> "CELossKwargs": + self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100) + return self + + def to(self, device: torch.device | str) -> "CELossKwargs": + self.shifted_labels = self.shifted_labels.to(device) + return self -class CELossContext(BaseLossContext): + +class LMHeadLossContext(BaseLossContext): """Cross-entropy loss context for language models. Args: @@ -147,6 +181,7 @@ def build_batches( # type: ignore[override] for loss_ctx in loss_ctx_list: loss_ctx._batch_size = len(loss_ctx_list) + assert loss_ctx.loss_kwargs.loss_weight is not None loss_ctx.loss_kwargs.loss_weight /= global_denominator + 1e-12 return loss_ctx_list @@ -179,15 +214,30 @@ def loss_fn( return loss, (logits, {}) + def eager_mode( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None, + loss_kwargs: CELossKwargs, + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs) + def chunk_mode( self, hidden_states: torch.Tensor, head_weight: torch.Tensor, head_bias: torch.Tensor | None, loss_kwargs: CELossKwargs, - ): + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: if self.loss_cfg.mode == "chunk": - return super().chunk_mode(hidden_states, head_weight, head_bias, loss_kwargs) + assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode" + + chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size) + loss, extra_info = 
ChunkLoss.apply( + hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size + ) + return loss, (None, extra_info) else: assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode" shifted_labels = loss_kwargs.shifted_labels # (bs, seq_len) @@ -209,3 +259,136 @@ def chunk_mode( @property def batch_size(self) -> int: return self._batch_size + + def forward( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None = None, + mtp_config = None, + layer_idx: int = 0, + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo + + assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward" + if head_bias is not None: + raise NotImplementedError("Loss does not support head_bias yet.") + + if mtp_config is not None: + if mtp_config.mask_type is None: + pass + elif mtp_config.mask_type == "v1": + self.process_loss_weights_v1(mtp_config, layer_idx) + elif mtp_config.mask_type == "v2": + assert layer_idx == 0, layer_idx + self.process_loss_weights_v2(mtp_config) + elif mtp_config.mask_type == "v3": + self.process_loss_weights_v3(mtp_config) + else: + raise NotImplementedError(mtp_config.mask_type) + + if self.loss_cfg.mode == "eager": + loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) + else: + loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) + + # TODO: yanhuida, should be removed + if not isinstance(extra_info, ModelForwardExtraLogInfo): + extra_info = ModelForwardExtraLogInfo(extra_info) + + extra_info["local_base_loss"] = loss.detach().clone() + + # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support + if dist.is_initialized(): + loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) + + return 
loss, (logits, extra_info) + + def process_loss_weights_v1(self, mtp_config, layer_idx): + shifted_labels = self.loss_kwargs.shifted_labels + loss_weight = self.loss_kwargs.loss_weight + sum_loss_weight = loss_weight.sum() + mtp_mask = torch.zeros_like(shifted_labels) + bsz, seq_len = shifted_labels.shape + inside_mtp_zone = False + assert bsz == 1, shifted_labels.shape + for j in range(seq_len): + token = shifted_labels[0, j].item() + + if token in mtp_config.open_token_list: + inside_mtp_zone = True + + if inside_mtp_zone: + flag = True + for i in range(layer_idx + 1 + 2): + token_id = shifted_labels[0, j - i].item() + token_in_list = token_id in mtp_config.open_token_list or token_id in mtp_config.close_token_list + flag = flag and not token_in_list + + if flag: + mtp_mask[0, j] = 1.0 + + if token in mtp_config.close_token_list: + inside_mtp_zone = False + loss_weight[mtp_mask == 0.0] = 0.0 + if loss_weight.sum().item() != 0: + loss_weight = loss_weight * sum_loss_weight / loss_weight.sum() + + self.loss_kwargs.loss_weight = loss_weight + + def process_loss_weights_v2(self, mtp_config): + + shifted_labels = self.loss_kwargs.shifted_labels + loss_weight = self.loss_kwargs.loss_weight + sum_loss_weight = loss_weight.sum() + + easy_to_use = torch.cat([ + shifted_labels, + torch.zeros((shifted_labels.size(0), 1), dtype=shifted_labels.dtype, device=shifted_labels.device) + ], dim=-1) + + is_digit = torch.where(easy_to_use < 25, easy_to_use > 14, 0) + is_dot = torch.where(easy_to_use == 13, 1, 0) + is_digit_or_dot = is_digit + is_dot + + mixed = is_digit_or_dot - torch.roll(is_digit_or_dot, shifts=1, dims=-1) + + left = mixed > 0 + right_unincluded = mixed < 0 + left = torch.roll(left, shifts=1, dims=-1) + need_cumsum = torch.where(left, 1, 0) + torch.where(right_unincluded, -1, 0) + + mtp_mask = torch.cumsum(need_cumsum, dim=-1).bool()[:, :-1] + + loss_weight[mtp_mask == 0.0] = 0.0 + if loss_weight.sum().item() != 0: + loss_weight = loss_weight * sum_loss_weight / 
def process_loss_weights_v3(self, mtp_config):
    """Keep loss weight only on numeric labels and their successors ("v3").

    A label counts as numeric when its id is in ``15..24`` or equals ``13``
    (the original names these "digit" and "dot"; the ids are presumably
    tokenizer specific -- TODO confirm against the tokenizer in use).  A
    position keeps its weight when its own label or the label directly before
    it is numeric; every other position gets weight zero.  If any weight
    survives, the result is rescaled so its sum equals the sum before masking.

    The new weights are written to ``self.loss_kwargs.loss_weight`` without
    modifying the previously stored tensor in place.

    Args:
        mtp_config: Unused; kept so all ``process_loss_weights_*`` variants
            can be called with the config object.
    """
    shifted_labels = self.loss_kwargs.shifted_labels
    loss_weight = self.loss_kwargs.loss_weight
    sum_loss_weight = loss_weight.sum()

    # Append one zero column so that ``torch.roll`` wraps a non-numeric value
    # (0) into position 0 instead of the last real label of the row.
    pad_column = torch.zeros(
        (shifted_labels.size(0), 1), dtype=shifted_labels.dtype, device=shifted_labels.device
    )
    padded = torch.cat([shifted_labels, pad_column], dim=-1)

    is_digit = (padded > 14) & (padded < 25)
    is_dot = padded == 13
    is_numeric = is_digit | is_dot

    # Position j is kept when label j or label j - 1 is numeric.
    keep_mask = (is_numeric | torch.roll(is_numeric, shifts=1, dims=-1))[:, :-1]

    loss_weight = torch.where(keep_mask, loss_weight, torch.zeros_like(loss_weight))
    masked_sum = loss_weight.sum()
    if masked_sum.item() != 0:
        # Preserve the total weight so the loss scale is unchanged by masking.
        loss_weight = loss_weight * sum_loss_weight / masked_sum

    self.loss_kwargs.loss_weight = loss_weight
class BalancingLossContext(nn.Module):
    """Balancing loss context for MoE models.

    Holds the balancing-loss configuration and a batch size used to scale the
    loss. All tensors needed for the computation are passed to ``forward``.

    Args:
        loss_cfg (BalancingLossConfig): The configuration for the balancing loss.
        loss_kwargs (BalancingLossKwargs): The keyword arguments for the balancing loss.
    """

    # NOTE(review): ``loss_cfg.router_scoring_func`` is never read anywhere in
    # this class -- confirm whether sigmoid routers need separate handling.

    def __init__(self, loss_cfg: BalancingLossConfig, loss_kwargs: BalancingLossKwargs):
        super().__init__()
        self.loss_cfg = loss_cfg
        self.loss_kwargs = loss_kwargs
        # Overwritten by ``build_batches`` with the number of contexts in the list.
        self._batch_size = 1

    @staticmethod
    def build_batches(
        loss_ctx_list: list["BalancingLossContext"],
    ) -> list["BalancingLossContext"]:
        """Build batches for balancing loss contexts.

        Stores ``len(loss_ctx_list)`` as the batch size on every context in the
        list; ``forward`` divides the loss by that value.

        Args:
            loss_ctx_list (list[BalancingLossContext]): List of loss contexts.

        Returns:
            list[BalancingLossContext]: The same list (modified in place) with batch_size set.
        """
        for loss_ctx in loss_ctx_list:
            loss_ctx._batch_size = len(loss_ctx_list)
        return loss_ctx_list

    def forward(
        self,
        router_weights: torch.Tensor,
        n_routed_experts: int,
        num_experts_per_tok: int,
    ) -> torch.Tensor:
        """Compute balancing loss.

        Per layer the loss is ``scale * sum_e(tokens_per_expert[e] * mean_weight[e])``;
        the per-layer values are summed, multiplied by ``balancing_loss_alpha``
        and divided by the batch size set through ``build_batches``.

        Args:
            router_weights (torch.Tensor): Router weights. Shape: (num_layers, seq_len, num_experts).
            n_routed_experts (int): Number of routed experts.
            num_experts_per_tok (int): Number of experts per token.

        Returns:
            torch.Tensor: Balancing loss value (scalar). A float32 zero tensor
            when ``balancing_loss_alpha`` is 0.
        """
        if self.loss_cfg.balancing_loss_alpha == 0:
            return torch.tensor(0.0, device=router_weights.device, dtype=torch.float32)

        num_layers = router_weights.shape[0]
        router_weights = router_weights.float()  # (nlayers, seq, ne)
        # Indices of the ``num_experts_per_tok`` largest weights for every token.
        _, selected_experts = torch.topk(router_weights, num_experts_per_tok, dim=-1)
        selected_experts_flat = selected_experts.view(num_layers, -1)
        # Shift expert ids by ``layer * n_routed_experts`` so one histogram can
        # count the selections of all layers at once.
        offset = torch.arange(num_layers, device=router_weights.device).unsqueeze(1) * n_routed_experts
        selected_experts_offset = selected_experts_flat + offset
        # ``bins == max - min`` gives unit-width bins, i.e. one bin per (layer, expert) id.
        tokens_per_expert_flat = torch.histc(
            selected_experts_offset.view(-1),
            bins=num_layers * n_routed_experts,
            min=0,
            max=num_layers * n_routed_experts,
        )
        tokens_per_expert = tokens_per_expert_flat.view(num_layers, n_routed_experts)  # (nlayers, ne)

        tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype)  # (nlayers, ne)
        if self.loss_cfg.balancing_loss_global_average and dist.is_initialized():
            # Sum the selection counts over all ranks of the WORLD group.
            tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD)  # type: ignore
            tokens_global = tokens_per_expert_global.sum(-1)  # (nlayers, )
            # Every token contributes ``num_experts_per_tok`` selections.
            seqlen_global = tokens_global // num_experts_per_tok
            # ``all_reduce_autograd`` is defined elsewhere in this module;
            # presumably an all-reduce that keeps gradients flowing -- TODO confirm.
            routing_weights_sum_global = all_reduce_autograd(router_weights.sum(dim=1), "sum", dist.group.WORLD)
            routing_weights_mean_global = routing_weights_sum_global / seqlen_global.unsqueeze(-1)
            scale_global = n_routed_experts / tokens_global
        else:
            # Local-only statistics: seq_len * num_experts_per_tok selections per layer.
            scale_global = n_routed_experts / (router_weights.shape[1] * num_experts_per_tok)
            routing_weights_mean_global = router_weights.mean(dim=1)

        loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1)
        loss = loss.sum() * self.loss_cfg.balancing_loss_alpha

        # Normalize by batch size for proper gradient accumulation
        loss = loss / self._batch_size

        return loss

    @property
    def batch_size(self) -> int:
        # Value set by ``build_batches`` (1 until then).
        return self._batch_size
+ """ + + def __init__(self, loss_cfg: ZLossConfig, loss_kwargs: ZLossKwargs): + super().__init__() + self.loss_cfg = loss_cfg + self.loss_kwargs = loss_kwargs + self._batch_size = 1 + + @staticmethod + def build_batches( + loss_ctx_list: list["ZLossContext"], + ) -> list["ZLossContext"]: + """Build batches for z-loss contexts. + + For z-loss, we set the batch size for proper gradient accumulation. + + Args: + loss_ctx_list (list[ZLossContext]): List of loss contexts. + + Returns: + list[ZLossContext]: The same list with batch_size set. + """ + for loss_ctx in loss_ctx_list: + loss_ctx._batch_size = len(loss_ctx_list) + return loss_ctx_list + + def forward(self, router_logits: torch.Tensor) -> torch.Tensor: + """Compute z-loss. + + Args: + router_logits (torch.Tensor): Router logits. Shape: (num_layers, seq_len, num_experts). + + Returns: + torch.Tensor: Z-loss value. + """ + if self.loss_cfg.z_loss_alpha == 0: + return torch.tensor(0.0, device=router_logits.device, dtype=torch.float32) + + router_logits = router_logits.float() # (nlayers, seq, ne) + num_seq = max(1, router_logits.shape[1]) + logsum_square = torch.logsumexp(router_logits, dim=-1).square() + loss = (logsum_square.sum(dim=-1) / num_seq).sum() + + if self.loss_cfg.z_loss_global_average and dist.is_initialized(): + unmasked_num = router_logits.shape[1] + unmasked_num_rank = torch.tensor(unmasked_num, device=router_logits.device, dtype=torch.int64) + group = dist.group.WORLD + assert group is not None + unmasked_num_global = all_reduce(unmasked_num_rank, "sum", group) + world_size = dist.get_world_size() + loss = loss * unmasked_num * world_size / unmasked_num_global + + loss = loss * self.loss_cfg.z_loss_alpha + + # Normalize by batch size for proper gradient accumulation + loss = loss / self._batch_size + + return loss + + @property + def batch_size(self) -> int: + return self._batch_size diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py new file mode 100644 index 
class MTPLossConfig(CELossConfig):
    """Cross-entropy loss configuration for one Multi-Token Prediction layer.

    Adds ``mtp_depth`` on top of ``CELossConfig``; ``build()`` rolls the labels
    by that many extra positions.  Intended for internal use by the model
    rather than as a user-facing option.

    Args:
        mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses
            ``mtp_depth=1`` (one extra position on top of the existing label shift).
    """

    mtp_depth: int

    @property
    def loss_ctx_cls(self) -> type["MTPLossContext"]:
        return MTPLossContext

    @property
    def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]:
        return MTPLossKwargs

    def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None":
        """Create the MTP loss context for one microbatch.

        ``data["shifted_labels"]`` is right-padded with ``-100`` up to
        ``cu_seq_lens_k[-1]`` when it is shorter, rolled by ``-mtp_depth``
        through ``roll_packed_tensor`` (with ``-100`` as fill value), moved to
        ``DEVICE`` and, for a sequence-parallel mesh larger than one, split
        with ``sp_split``.

        Args:
            data (dict): Data dict containing loss-related fields.
                Required keys: ``shifted_labels``, ``seq_ctx``.
            sp_mesh (DeviceMesh | None): Sequence parallel mesh.

        Returns:
            MTPLossContext | None: Built loss context, or ``None`` if
            ``shifted_labels`` is not present in ``data``.
        """
        # TODO: Should move the common utils function to public package to avoid from circular import.
        from xtuner.v1.module.mtp.utils import roll_packed_tensor

        if "shifted_labels" not in data:
            return None

        labels = data["shifted_labels"]
        boundaries = data["seq_ctx"].cu_seq_lens_k

        # The last cumulative length can exceed the label length when seq_ctx
        # was padded for sequence parallelism (to a multiple of sp_size).
        # Extend the labels with ignore-index entries so the roll stays in bounds.
        missing = int(boundaries[-1].item()) - labels.shape[-1]
        if missing > 0:
            filler = labels.new_full((*labels.shape[:-1], missing), -100)
            labels = torch.cat([labels, filler], dim=-1)

        rolled_labels = roll_packed_tensor(labels, boundaries, shifts=-self.mtp_depth, dim=-1, fill_value=-100)

        kwargs = MTPLossKwargs(shifted_labels=rolled_labels).to(DEVICE)
        if sp_mesh is None or sp_mesh.size() <= 1:
            return MTPLossContext(self, kwargs)
        return MTPLossContext(self, kwargs.sp_split(sp_mesh))
+ """ diff --git a/xtuner/v1/loss/rl_loss.py b/xtuner/v1/loss/rl_loss.py index 4193e95970..66e5352770 100644 --- a/xtuner/v1/loss/rl_loss.py +++ b/xtuner/v1/loss/rl_loss.py @@ -18,8 +18,10 @@ class LogProbConfig(BaseLossConfig): def loss_ctx_cls(self) -> type["LogProbContext"]: return LogProbContext - def build(self, shifted_labels: torch.Tensor, sp_mesh: DeviceMesh | None = None) -> "LogProbContext": - loss_kwargs = LogProbKwargs(shifted_labels=shifted_labels) + def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "LogProbContext | None": + if "shifted_labels" not in data: + return None + loss_kwargs = LogProbKwargs(shifted_labels=data["shifted_labels"]) if sp_mesh is not None and sp_mesh.size() > 1: loss_kwargs = loss_kwargs.sp_split(sp_mesh) return self.loss_ctx_cls(self, loss_kwargs) @@ -83,5 +85,5 @@ def forward( if self.loss_cfg.mode == "chunk": logprobs, _ = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) else: - logprobs, _ = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) + logprobs, _ = self.loss_fn(hidden_states, head_weight, head_bias, self.loss_kwargs) return logprobs, (None, {}) diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index b0744864f6..0a29a4530c 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -23,10 +23,9 @@ from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig from .moe.deepseek_v3 import DeepSeekV3Config from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig -from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig +from .moe.moe import BalancingLossConfig, MoE, MoEConfig, MoEModelOutputs, ZLossConfig from .moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig, Qwen3MoEFoPEConfig - model_mapping = { "qwen3-moe-30BA3": Qwen3MoE30BA3Config(), "qwen3-8B": Qwen3Dense8BConfig(), @@ -87,6 +86,7 @@ def get_model_config_from_hf(model_path: 
Path): "get_model_config", "get_model_config_from_hf", "MoE", + "MoEConfig", "MoEModelOutputs", "BalancingLossConfig", "ZLossConfig", diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index 071cd4ebc8..0dba8080e2 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -8,7 +8,7 @@ from itertools import chain from pathlib import Path from shutil import copy, copytree -from typing import Annotated, Generator, Iterable, Literal, Mapping, Sequence, cast +from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, cast import torch import torch.distributed as dist @@ -39,7 +39,7 @@ WeightWithDynamicTensorWiseFloat8CastTensor, WeightWithDynamicTilewiseFloat8CastTensor, ) -from xtuner.v1.loss import BaseLossContext +from xtuner.v1.loss import BaseLossConfig, BaseLossContext, CELossConfig from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig, MLAConfig from xtuner.v1.module.rope import RopeScalingConfig from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather @@ -106,6 +106,7 @@ class XTunerBaseModelConfig(PydanticBaseModel): ] = None hf_key_mapping: Annotated[dict[str, str] | None, "Remapping hf key based on the `to_hf_key_list`"] = None dcp_ignore_frozen_params: bool = True + lm_loss_cfg: BaseLossConfig = CELossConfig() @property def hf_config(self) -> PretrainedConfig | None: @@ -250,7 +251,7 @@ def _is_float8_available(): class ModelItem(TypedDict): seq_ctx: SequenceContext - loss_ctx: BaseLossContext + loss_ctx: dict[str, BaseLossContext] | None def is_float8_weight(tensor): @@ -673,6 +674,82 @@ def _to_float8( name_list_new.extend([name, f"{name}_scale_inv"]) return gathered_tensor_list_new, name_list_new + def build_loss_ctx_batch( + self, + data_batch: list[dict], + sp_mesh: DeviceMesh | None = None, + ) -> list[dict[str, dict]]: + """Build and calibrate loss contexts for the entire batch. + + For Dense model, only LM loss is needed. 
+ + Args: + data_batch (list[dict]): All microbatch data + sp_mesh (DeviceMesh | None): Sequence parallel mesh + cu_seq_lens_list (list[torch.IntTensor] | None): For calibration + + Returns: + list[dict[str, BaseLossContext]]: Loss context dict for each microbatch + """ + cu_seq_lens_list = [data["seq_ctx"].cu_seq_lens_k for data in data_batch] + res: list[dict] = [{} for _ in range(len(data_batch))] + + lm_loss_ctx_list = self._build_loss_ctx(self.config.lm_loss_cfg, data_batch, sp_mesh) + + if lm_loss_ctx_list is not None: + loss_ctx_cls = lm_loss_ctx_list[0].__class__ + lm_loss_ctx_list = loss_ctx_cls.build_batches( + lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh + ) + + if lm_loss_ctx_list is not None: + for i, lm_loss_ctx in enumerate(lm_loss_ctx_list): + res[i]["lm"] = lm_loss_ctx + + return res + + def _add_auxiliary_loss( + self, + loss_name: str, + loss_cfg: Any, + data_batch: list[dict], + res: list[dict], + ) -> None: + """Add auxiliary loss contexts to result. + + This helper builds loss contexts, calibrates them across the batch, + and adds them to the result dictionary. If loss_cfg is None, does nothing. + + Args: + loss_name (str): Name of the loss (e.g., "balancing", "z_loss"). + loss_cfg (Any): Loss configuration with a build() method. If None, skipped. + data_batch (list[dict]): Batch data. + res (list[dict]): Result dictionary to populate. Modified in-place. 
+ + Example: + def build_loss_ctx_batch(self, data_batch, sp_mesh): + res = super().build_loss_ctx_batch(data_batch, sp_mesh) + + # One line per auxiliary loss + self._add_auxiliary_loss("balancing", self.config.balancing_loss_cfg, data_batch, res) + self._add_auxiliary_loss("z_loss", self.config.z_loss_cfg, data_batch, res) + + return res + """ + if loss_cfg is None: + return + + # Build loss contexts for all microbatches + ctx_list = [loss_cfg.build() for _ in data_batch] + + # Calibrate across batch + ctx_cls = ctx_list[0].__class__ + ctx_list = ctx_cls.build_batches(ctx_list) + + # Add to result + for i, ctx in enumerate(ctx_list): + res[i][loss_name] = ctx # type: ignore + def pre_micro_batch_forward(self, data_batches: Sequence[ModelItem]) -> DataBatchInfo: step_consumed_tokens = torch.tensor(0, device=DEVICE) step_consumed_img_tokens = torch.tensor(0.0, device=DEVICE) @@ -719,7 +796,18 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat output_copy = output.model_copy() for name in output_copy.model_fields: obj = getattr(output_copy, name) - if "loss" in name and isinstance(obj, torch.Tensor): + if name == "mtp_loss": + for key, value in obj.items(): + loss_item = value.item() + local_total_loss += loss_item + reduced_name = f"{key}_reduced_mtp_loss" + + if reduced_name not in reduced_other_losses: + reduced_other_losses[reduced_name] = loss_item + else: + reduced_other_losses[reduced_name] += loss_item + + elif "loss" in name and isinstance(obj, torch.Tensor): loss_item = obj.item() local_total_loss += loss_item reduced_name = f"reduced_{name}" @@ -1445,6 +1533,9 @@ def _load_fused_hf_param( continue _loaded_tensor.append(weight.to(local_tensor.device)) + if not _loaded_tensor: + return missing_keys + if not hf_keys: # fp8 pad assert self.config.float8_cfg is not None @@ -1784,19 +1875,34 @@ def _collect_full_state_dict(self, module: nn.Module): ret[name] = param return ret + def _build_loss_ctx( + self, loss_ctx_cfg: 
BaseLossConfig | None, data_batch: list[dict], sp_mesh: DeviceMesh | None + ) -> list[BaseLossContext] | None: + if loss_ctx_cfg is None: + return None + + first_loss_ctx = loss_ctx_cfg.build(data=data_batch[0], sp_mesh=sp_mesh) + # If first build returns None, assume all data in the batch have the same schema + # and will also return None (e.g., missing required fields like shifted_labels) + if first_loss_ctx is None: + return None + else: + ret = [first_loss_ctx] + [loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]] + return ret # type: ignore[return-value] + # NOTE: Add this overload for inferring the return type for easier type checking and using @overload # type: ignore def __call__( # type: ignore self, seq_ctx: SequenceContext, - loss_ctx: BaseLossContext | None, + loss_ctx: dict[str, BaseLossContext] | None, ) -> ModelOutputs: ... @overload # type: ignore def __call__( # type: ignore self, seq_ctx: list[SequenceContext], - loss_ctx: list[BaseLossContext], + loss_ctx: list[dict[str, BaseLossContext]], ) -> ModelOutputs: ... 
@override
def build_loss_ctx_batch(  # type: ignore[override]
    self,
    data_batch: list[dict],
    sp_mesh: DeviceMesh | None = None,
) -> list[dict[str, BaseLossContext]]:
    """Delegate loss_ctx building to the language model.

    Args:
        data_batch (list[dict]): All microbatch data, forwarded unchanged.
        sp_mesh (DeviceMesh | None): Sequence parallel mesh, forwarded unchanged.

    Returns:
        list[dict[str, BaseLossContext]]: Whatever
        ``self.language_model.build_loss_ctx_batch`` returns.
    """
    # TODO: Maybe we need to consider the `loss_ctx` of vision model.
    return self.language_model.build_loss_ctx_batch(data_batch, sp_mesh=sp_mesh)
a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py @@ -141,7 +141,7 @@ def get_placeholder_mask( def forward( self, seq_ctx: SequenceContext, - loss_ctx: CELossContext + loss_ctx: dict[str, CELossContext] | None = None ) -> MoEModelOutputs: input_ids = seq_ctx.input_ids pixel_values = seq_ctx.pixel_values diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index 36c9465a86..35c16db602 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -19,7 +19,7 @@ from xtuner.v1.config import FSDPConfig from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.float8_handler import Float8Handler -from xtuner.v1.loss import CELossContext +from xtuner.v1.loss import BaseLossContext, CELossContext from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -78,7 +78,7 @@ def __init__(self, config: TransformerConfig): def forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: CELossContext, + loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None, ) -> ModelOutputs: input_ids = seq_ctx.input_ids position_ids = seq_ctx.position_ids @@ -110,10 +110,17 @@ def forward( hidden_states = self.norm(hidden_states) - loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx) - output["loss"] = loss - output["logits"] = logits - output["extra_info"] = extra_info + if loss_ctx is None: + # Inference mode + logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias) + output["logits"] = logits + else: + # Training mode + loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload] + output["loss"] = loss + output["logits"] = logits + output["extra_info"] = extra_info + return ModelOutputs(**output) def build_embeddings(self, config: TransformerConfig): @@ -166,7 +173,7 @@ def default_compile_cfg(self) -> dict[str, 
TorchCompileOption]: def __call__( # type: ignore self, seq_ctx: SequenceContext, - loss_ctx: CELossContext, + loss_ctx: dict[str, CELossContext] | None = None, ) -> ModelOutputs: ... __call__ = nn.Module.__call__ diff --git a/xtuner/v1/model/dense/qwen3vl_text.py b/xtuner/v1/model/dense/qwen3vl_text.py index 15e40c40b0..308a9817d8 100644 --- a/xtuner/v1/model/dense/qwen3vl_text.py +++ b/xtuner/v1/model/dense/qwen3vl_text.py @@ -1,9 +1,10 @@ import re import torch +import torch.nn.functional as F from xtuner.v1.data_proto import SequenceContext -from xtuner.v1.loss import CELossContext +from xtuner.v1.loss import BaseLossContext from xtuner.v1.model.base import ModelOutputs from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig @@ -34,10 +35,10 @@ def _deepstack_process( hidden_states[visual_pos_masks, :] = local_this return hidden_states - def forward( + def forward( # type: ignore[override] self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: CELossContext, + loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None, ) -> ModelOutputs: input_ids = seq_ctx.input_ids position_ids = seq_ctx.position_ids @@ -78,11 +79,18 @@ def forward( hidden_states = self.norm(hidden_states) - loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx) - output["loss"] = loss - output["logits"] = logits - output["extra_info"] = extra_info - return ModelOutputs(**output) # type: ignore[typeddict-item] + if loss_ctx is None: + # Inference mode + logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias) + output["logits"] = logits + else: + # Training mode + loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload] + output["loss"] = loss + output["logits"] = logits + output["extra_info"] = extra_info + + return ModelOutputs(**output) class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig): diff --git a/xtuner/v1/model/moe/moe.py 
b/xtuner/v1/model/moe/moe.py index 29f5bc4a5f..6ae72aae69 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -2,13 +2,12 @@ import os import types from pathlib import Path -from typing import Annotated, Literal, Self, Sequence, cast +from typing import TYPE_CHECKING, Annotated, Literal, Self, Sequence, TypedDict, cast, List import torch import torch.distributed as dist import torch.nn.functional as F from cyclopts import Parameter -from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict from torch import nn from torch.distributed._functional_collectives import all_reduce @@ -26,7 +25,16 @@ from xtuner.v1.config import FSDPConfig from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.float8_handler import Float8Handler -from xtuner.v1.loss import BalancingLoss, CELossContext, ZLoss +from xtuner.v1.loss import ( + BalancingLossConfig, + BalancingLossContext, + BaseLossContext, + LMHeadLossContext, + MTPLossContext, + ZLossConfig, + ZLossContext, +) +from xtuner.v1.loss.mtp_loss import MTPLossConfig from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -50,6 +58,7 @@ ) from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig, MoEBlock, MoEDecoderLayer +from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer from xtuner.v1.utils import ( get_device, get_logger, @@ -57,6 +66,10 @@ from xtuner.v1.utils.activation_offload import async_save_on_cpu +if TYPE_CHECKING: + from xtuner.v1.datasets.collator import ColateItem + + DEVICE = get_device() logger = get_logger() @@ -87,6 +100,7 @@ class MoEModelOutputs(ModelOutputs): balancing_loss: torch.Tensor | None = None z_loss: torch.Tensor | None = None tokens_per_expert_global: torch.Tensor + mtp_loss: dict[str, torch.Tensor] | None = None def free_nongrad_feature(self): """Release large intermediate tensors not needed for backward or @@ 
class MoELossContextDict(TypedDict):
    """Per-microbatch loss contexts of an MoE model, keyed by loss name."""

    # NOTE(review): every key is declared as required, but
    # ``_add_auxiliary_loss`` returns without adding its key when the matching
    # config is ``None``, and the base ``build_loss_ctx_batch`` adds ``"lm"``
    # only when an LM loss context was built. Absent keys therefore look
    # possible at runtime -- confirm how consumers read this dict.

    # Language-modelling loss context.
    lm: BaseLossContext
    # Balancing loss context.
    balancing: BalancingLossContext | None
    # Z-loss context.
    z_loss: ZLossContext | None
    # MTP loss contexts -- presumably one per MTP layer; TODO confirm.
    mtp: list[BaseLossContext] | None
def build_loss_ctx_batch(  # type: ignore[override]
    self,
    data_batch: list["ColateItem"],
    sp_mesh: DeviceMesh | None = None,
) -> list[MoELossContextDict]:  # type: ignore[override]
    """Build the loss contexts of every microbatch for the MoE model.

    Args:
        data_batch (list[ColateItem]): Data of all microbatches. Each item is
            indexed with ``"seq_ctx"`` to read ``cu_seq_lens_k``.
        sp_mesh (DeviceMesh | None): Sequence parallel mesh.

    Returns:
        list[MoELossContextDict]: One dict per microbatch, containing:
            - "lm": LM loss context (built by the parent class)
            - "balancing" / "z_loss": filled in by ``_add_auxiliary_loss``
              from ``balancing_loss_cfg`` / ``z_loss_cfg``
            - "mtp": ``{mtp_name: [ctx_depth1, ctx_depth2, ...]}`` or ``None``
              when MTP is disabled or no MTP context could be built

    Note:
        Auxiliary loss contexts receive their data at forward time:
            - balancing_ctx(router_weights, n_routed_experts, num_experts_per_tok)
            - z_loss_ctx(router_logits)
    """
    # Build LM loss context
    _data_batch: list[dict] = data_batch  # type: ignore[assignment]
    res: list[dict] = super().build_loss_ctx_batch(_data_batch, sp_mesh)
    cu_seq_lens_list = [data["seq_ctx"].cu_seq_lens_k for data in data_batch]

    # Add auxiliary losses
    self._add_auxiliary_loss("balancing", self.config.balancing_loss_cfg, _data_batch, res)
    self._add_auxiliary_loss("z_loss", self.config.z_loss_cfg, _data_batch, res)

    mtp_configs = self.config.mtp_config
    if mtp_configs is not None:
        # Every MTP depth of every named MTP block gets its own loss context,
        # derived from the LM loss config.
        for mtp_config in mtp_configs:
            for mtp_idx in range(mtp_config.num_layers):
                mtp_loss_cfg = MTPLossConfig(
                    **self.config.lm_loss_cfg.model_dump(),
                    mtp_depth=mtp_idx + 1,
                )
                # Each depth shifts the labels a different number of times, so
                # the contexts are rebuilt from `data_batch` for every depth.
                mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh)
                if mtp_loss_ctx_list is None:
                    continue

                mtp_loss_ctx_list = MTPLossContext.build_batches(  # type: ignore[assignment]
                    cast(list[MTPLossContext], mtp_loss_ctx_list),  # type: ignore[arg-type]
                    cu_seq_lens_list=cu_seq_lens_list,
                    sp_mesh=sp_mesh,
                )
                for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list):
                    per_name = res[i].setdefault("mtp", {})
                    per_name.setdefault(mtp_config.name, []).append(mtp_loss_ctx)  # type: ignore[union-attr]

    # Every microbatch exposes an "mtp" key, even when nothing was built.
    for loss_ctx_dict in res:
        loss_ctx_dict.setdefault("mtp", None)

    return res  # type: ignore[return-value]
+ ) return self._micro_batch_forward( seq_ctx_list=seq_ctx, @@ -338,12 +398,8 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[MoEModelOutputs]) -> base_info = super().post_micro_batch_forward(batch_outputs) logs_info = base_info["logs_info"] - tokens_per_expert_global = torch.zeros( - self.config.num_hidden_layers - self.config.first_k_dense_replace, - self.config.n_routed_experts, - dtype=torch.int64, - device=DEVICE, - ) + first_tokens_per_expert = batch_outputs[0]["tokens_per_expert_global"] + tokens_per_expert_global = torch.zeros_like(first_tokens_per_expert) for output in batch_outputs: tokens_per_expert_global += output["tokens_per_expert_global"] @@ -362,7 +418,7 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[MoEModelOutputs]) -> def _micro_batch_forward( self, seq_ctx_list: list[SequenceContext], - loss_ctx_list: list[CELossContext], + loss_ctx_list: list[MoELossContextDict], return_router_logits: bool = False, ) -> MoEModelOutputs: """Micro-batch forward pass for MoE model. 
@@ -463,8 +519,10 @@ def _micro_batch_forward( cat_hidden_states = self.norm(cat_hidden_states) # Process final outputs for each micro-batch - cat_loss_ctx = CELossContext.cat(loss_ctx_list) - loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx) + # Extract LM loss context from dict + lm_loss_ctx_list = [loss_ctx_dict["lm"] for loss_ctx_dict in loss_ctx_list] + cat_loss_ctx = type(lm_loss_ctx_list[0]).cat(lm_loss_ctx_list) + loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cast(LMHeadLossContext, cat_loss_ctx)) # Aggregate losses (mean across micro-batches) output["loss"] = loss.sum() @@ -473,6 +531,10 @@ def _micro_batch_forward( moe_extra_info.append(extra_info) output["extra_info"] = moe_extra_info + # MTP forward pass and loss computation for micro-batch mode + if self.mtp_block is not None: + raise NotImplementedError + # Handle router results for all micro-batches all_router_logits = [] all_router_weights = [] @@ -495,23 +557,35 @@ def _micro_batch_forward( combined_router_logits = torch.cat(all_router_logits, dim=1) # [num_layers, total_seq, num_experts] combined_router_weights = torch.cat(all_router_weights, dim=1) - # Calculate balancing loss across all micro-batches - batch_size = loss_ctx_list[0].batch_size if loss_ctx_list else 1 - if self.balancing_loss: - balancing_loss = ( - self.balancing_loss( - router_weights=combined_router_weights, - n_routed_experts=self.config.n_routed_experts, - num_experts_per_tok=self.config.num_experts_per_tok, + # Build balancing loss contexts + balancing_loss_ctx_list: list[BalancingLossContext] = [] + for loss_ctx_dict in loss_ctx_list: + bal_ctx = loss_ctx_dict.get("balancing") + if bal_ctx is not None: + balancing_loss_ctx_list.append(bal_ctx) + + if balancing_loss_ctx_list: + # Compute balancing loss by passing all parameters to forward + balancing_loss = sum( + ctx( + combined_router_weights, + self.config.n_routed_experts, + self.config.num_experts_per_tok, ) - / batch_size - * 
def _mtp_forward(
    self,
    mtp_config: MTPConfig,
    output,
    layer_hidden_states,
    position_embeddings,
    seq_ctx,
    mtp_seq_ctx,
    mtp_loss_ctx_dict,
):
    """Run one named MTP block and record its loss and router outputs.

    Args:
        mtp_config (MTPConfig): Config of the MTP block to run; its ``name``
            selects the block in ``self.mtp_block`` and the loss contexts in
            ``mtp_loss_ctx_dict``.
        output: Mutable forward-output mapping. ``output["router_logits"]``,
            ``output["router_weights"]`` and ``output["mtp_loss"]`` must
            already exist; this method adds entries to them.
        layer_hidden_states: Hidden states fed to the MTP block.
        position_embeddings: Position embeddings forwarded to the MTP block.
        seq_ctx: Unused here; kept so existing call sites stay valid.
        mtp_seq_ctx: Sequence context passed to the MTP block.
        mtp_loss_ctx_dict: Maps MTP block names to per-depth loss contexts.

    Returns:
        The MTP loss summed over the paired depths, divided by the number of
        loss contexts and multiplied by ``mtp_config.loss_scaling_factor``.
    """
    name = mtp_config.name

    # Forward through the MTP block registered under this name
    mtp_outputs = self.mtp_block[name](
        hidden_states=layer_hidden_states,
        embed_tokens_fn=self.embed_tokens,
        position_embeddings=position_embeddings,
        seq_ctx=mtp_seq_ctx,
    )

    # Compute MTP losses for each depth
    mtp_losses = torch.tensor(0.0, device=DEVICE)
    mtp_loss_ctx_list = mtp_loss_ctx_dict[name]
    # NOTE(review): zip stops at the shorter sequence, so depths without a loss
    # context are ignored rather than reported - confirm this is intended.
    for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)):
        mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden
        mtp_loss, _ = self.lm_head(
            mtp_hidden_states,
            cast(MTPLossContext, mtp_ctx),
            mtp_config=mtp_config,
            layer_idx=idx,
        )
        mtp_losses += mtp_loss

        layer_key = f"{name}_mtp_layer{idx}"
        output["router_logits"][layer_key] = mtp_router_results
        output["router_weights"][layer_key] = mtp_router_weights

    # Average MTP losses across depths and scale
    mtp_losses = mtp_losses / len(mtp_loss_ctx_list)
    scaled_mtp_loss = mtp_losses * mtp_config.loss_scaling_factor  # type: ignore

    # Record the per-block loss; the caller owns how it joins the total loss
    output["mtp_loss"][name] = scaled_mtp_loss

    return scaled_mtp_loss
def build_mtp_block_dict(self, config: MoEConfig) -> nn.ModuleDict:
    """Build one MTP block per entry of ``config.mtp_config``, keyed by name.

    Args:
        config (MoEConfig): Model configuration; ``config.mtp_config`` must
            not be ``None``.

    Returns:
        nn.ModuleDict: Maps ``mtp_config.name`` to the built ``MTPBlock``.
    """
    mtp_block_dict = nn.ModuleDict()
    # NOTE(review): two configs with the same name overwrite each other here.
    for mtp_config in config.mtp_config:
        mtp_block_dict[mtp_config.name] = self.build_mtp_block(config, mtp_config)

    return mtp_block_dict


def build_mtp_block(self, config: MoEConfig, mtp_config: MTPConfig) -> MTPBlock:
    """Build an MTP block made of MoE decoder layers.

    Args:
        config (MoEConfig): Model configuration supplying the decoder-layer
            hyperparameters.
        mtp_config (MTPConfig): Config of the MTP block; ``num_layers`` sets
            how many MTP layers are built.

    Returns:
        MTPBlock: Constructed MTP block.

    Raises:
        ValueError: If the last decoder layer has an unsupported layer type.
    """
    assert mtp_config is not None, "mtp_config must be provided"

    # MTP layers reuse the layer type and attention config of the last decoder layer
    last_layer_type = config.layers_type[config.num_hidden_layers - 1]
    attention_config: MLAConfig | MHAConfig | GatedDeltaNetConfig
    if last_layer_type in ["full_attention", "sliding_attention"]:
        attention_config = config.attention
    elif last_layer_type == "linear_attention":
        assert config.linear_attention is not None, (
            "linear_attention config must be provided for linear_attention layer"
        )
        attention_config = config.linear_attention
    else:
        raise ValueError(f"Unsupported layer type {last_layer_type}")

    mtp_layers = []
    for i in range(mtp_config.num_layers):
        # Build MoE decoder layer for MTP
        decoder_layer = MoEDecoderLayer(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            moe_intermediate_size=config.moe_intermediate_size,
            mlp_bias=config.mlp_bias,
            gate_bias=config.gate_bias,
            moe_bias=config.moe_bias,
            hidden_act=config.hidden_act,
            rms_norm_eps=config.rms_norm_eps,
            rms_norm_type=config.rms_norm_type,
            num_experts_per_tok=config.num_experts_per_tok,
            n_routed_experts=config.n_routed_experts,
            n_shared_experts=config.n_shared_experts,
            with_shared_expert_gate=config.with_shared_expert_gate,
            hidden_factor=config.hidden_factor,
            layer_type=last_layer_type,
            attention_config=attention_config,
            rope_scaling_cfg=config.rope_scaling_cfg,
            generate_config=config.generate_config,
            router_config=config.router,
            moe_act_fn_cfg=config.moe_act_fn_cfg,
            float8_cfg=config.float8_cfg,
            # MTP layers are numbered after the decoder layers
            layer_idx=config.num_hidden_layers + i,
            dispatcher=config.dispatcher,
            ep_mesh=self.ep_mesh,
        )

        # Wrap decoder layer in MTPLayer
        mtp_layer = MTPLayer(
            hidden_size=config.hidden_size,
            rms_norm_eps=config.rms_norm_eps,
            rms_norm_type=config.rms_norm_type,
            decoder_layer=decoder_layer,
            float8_cfg=config.float8_cfg,
        )
        mtp_layers.append(mtp_layer)

    return MTPBlock(mtp_layers=mtp_layers)
len(mtp_block.layers) - 1 + self._fully_shard( + mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, + mp_policy=mp_policy, + reshard_after_forward=reshard_after_forward, + offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, + module=mtp_layer, + ) + if mtp_idx == 0: + layer_next.set_modules_to_forward_prefetch([mtp_layer]) # type: ignore + + # if self.config.mtp_config is not None and self.config.mtp_config.num_layers > 0: + if self.config.mtp_config is not None: + mtp_block_layers = [] + for mtp_config in self.config.mtp_config: + mtp_block_layers.extend(list(self.mtp_block[mtp_config.name].layers)) + # for prev_mtp_layer, next_mtp_layer in zip( + # list(self.mtp_block.layers)[:-1], + # list(self.mtp_block.layers)[1:], + # ): + for prev_mtp_layer, next_mtp_layer in zip( + mtp_block_layers[:-1], + mtp_block_layers[1:], + ): + prev_mtp_layer.set_modules_to_forward_prefetch([next_mtp_layer]) # type: ignore + self._fully_shard( mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh, mp_policy=mp_policy, @@ -996,19 +1237,79 @@ def patched_emb_forward(self, input): self.sparse, ) + def _should_recompute( + self, + layer_idx: int | None, + mtp_idx: int | None, + ) -> bool: + """Determine if a layer should use gradient checkpointing + (recomputation). + + The recomputation strategy treats decoder layers and MTP layers as a single + sequence. The recompute_ratio is applied to the total layer count. The last + layer in the entire model is never recomputed to avoid unnecessary overhead. + + Args: + layer_idx (int | None): Index of the decoder layer (0-based). None if this + is an MTP layer. + mtp_idx (int | None): Index of the MTP layer (0-based). None if this is a + decoder layer. + + Returns: + bool: True if the layer should use gradient checkpointing, False otherwise. 
+ + Example: + Configuration: 7 decoder layers, 3 MTP layers, recompute_ratio=0.8 + - Total layers: 10 + - Recompute layers: int(10 * 0.8) = 8 + - Layer mapping: + * Decoder 0-6 → global index 0-6 (7 layers) + * MTP 0-2 → global index 7-9 (3 layers) + - Recomputation decision: + * Global 0-7 (decoder 0-6, MTP 0): recompute ✓ + * Global 8 (MTP 1): no recompute + * Global 9 (MTP 2, last layer): no recompute (forced) + """ + num_layers = self.config.num_hidden_layers + if self.config.mtp_config is None: + mtp_layers = 0 + else: + mtp_layers = sum([mtp_config.num_layers for mtp_config in self.config.mtp_config]) + # mtp_layers = self.config.mtp_config.num_layers if self.config.mtp_config is not None else 0 + recompute_ratio = self.fsdp_config.recompute_ratio if self.fsdp_config is not None else 0.0 + + total_layers = num_layers + mtp_layers + num_recompute_layers = int(total_layers * recompute_ratio) + + # Determine the global layer index (0-based) + if layer_idx is not None: + # This is a decoder layer + global_idx = layer_idx + else: + # This is an MTP layer (comes after all decoder layers) + assert mtp_idx is not None, "Either layer_idx or mtp_idx must be provided" + global_idx = num_layers + mtp_idx + + # Last layer is never recomputed + if global_idx == total_layers - 1: + return False + + # Recompute if within the recompute range + return global_idx < num_recompute_layers + # NOTE: Add this overload for inferring the return type for easier type checking and using @overload # type: ignore def __call__( # type: ignore self, seq_ctx: SequenceContext, - loss_ctx: CELossContext | None, + loss_ctx: MoELossContextDict | None, ) -> MoEModelOutputs: ... @overload # type: ignore def __call__( # type: ignore self, seq_ctx: list[SequenceContext], - loss_ctx: list[CELossContext], + loss_ctx: list[MoELossContextDict], ) -> MoEModelOutputs: ... 
__call__ = nn.Module.__call__ diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index 6ff2184129..89040f3df6 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -43,6 +43,74 @@ class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): def to_hf_key_list(self, key: str) -> list[str]: + # Handle MTP parameters + if key.startswith("mtp_block."): + # Remove "mtp_block." prefix + key = key.replace("mtp_block.", "", 1) + + # Handle MTP layer-specific parameters + # xtuner: mtp_block.layers.{idx}.decoder_layer.{param} + # HF: mtp.layers.{idx}.{param} + key = re.sub(r"layers\.(\d+)\.decoder_layer\.", r"layers.\1.", key) + + # Handle MTP normalization layers + # xtuner: mtp_block.layers.{idx}.enorm -> HF: mtp.pre_fc_norm_embedding + # xtuner: mtp_block.layers.{idx}.hnorm -> HF: mtp.pre_fc_norm_hidden + # xtuner: mtp_block.layers.{idx}.final_layernorm -> HF: mtp.norm + # Note: Currently assuming single MTP layer (idx=0), may need adjustment for multiple layers + # if ".enorm." in key: + # key = re.sub(r"layers\.\d+\.enorm\.", "pre_fc_norm_embedding.", key) + # elif ".hnorm." in key: + # key = re.sub(r"layers\.\d+\.hnorm\.", "pre_fc_norm_hidden.", key) + # elif ".final_layernorm." in key: + # key = re.sub(r"layers\.\d+\.final_layernorm\.", "norm.", key) + + # Handle MTP normalization layers + # xtuner: mtp_block.{mtp_key}.layers.{idx}.enorm -> HF: mtp.{mtp_key}.layers.{idx}.pre_fc_norm_embedding + # xtuner: mtp_block.{mtp_key}.layers.{idx}.hnorm -> HF: mtp.{mtp_key}.layers.{idx}.pre_fc_norm_hidden + # xtuner: mtp_block.{mtp_key}.layers.{idx}.final_layernorm -> HF: mtp.{mtp_key}.layers.{idx}.norm + if ".enorm." in key: + key = re.sub(r"layers\.(\d+)\.enorm\.", r"layers.\1.pre_fc_norm_embedding.", key) + elif ".hnorm." in key: + key = re.sub(r"layers\.(\d+)\.hnorm\.", r"layers.\1.pre_fc_norm_hidden.", key) + elif ".final_layernorm." 
in key: + key = re.sub(r"layers\.(\d+)\.final_layernorm\.", r"layers.\1.norm.", key) + + + # Handle MTP projection layer + # xtuner: mtp_block.{mtp_key}.layers.{idx}.eh_proj -> HF: mtp.fc + # if ".eh_proj." in key: + # key = re.sub(r"layers\.\d+\.eh_proj\.", "fc.", key) + + # Handle MTP projection layer + # xtuner: mtp_block.{mtp_key}.layers.{idx}.eh_proj -> HF: mtp.{mtp_key}.layers.{idx}.fc + if ".eh_proj." in key: + key = re.sub(r"layers\.(\d+)\.eh_proj\.", r"layers.\1.fc.", key) + + # Handle MoE-specific transformations within MTP layers + key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts|shared_expert_gate)", r"layers.\1.mlp.\2", key) + key = key.replace("shared_experts", "shared_expert") + + # Handle fused weights + n_routed_experts = self.config.n_routed_experts + if "fused_w1w3.weight" in key: + w1w3_keys: list[str] = [] + + for i in range(n_routed_experts): + w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.gate_proj.weight")) + w1w3_keys.append(key.replace("fused_w1w3.weight", f"{i}.up_proj.weight")) + + return [f"mtp.{key}" for key in w1w3_keys] + + elif "fused_w2.weight" in key: + w2_keys: list[str] = [] + for i in range(n_routed_experts): + w2_keys.append(key.replace("fused_w2.weight", f"{i}.down_proj.weight")) + return [f"mtp.{key}" for key in w2_keys] + else: + return ["mtp." + key] + + # Handle main model parameters if "layers" in key or "embed_tokens" in key: key = "model.language_model." 
+ key @@ -86,13 +154,13 @@ def safetensors_to_params( else: loaded_tensor = safetensors[0] - if "fused_w1w3.weight" in param_name: + if "fused_w1w3.weight" in param_name and "mtp" not in param_name: # hf: num_experts, 2 * expert_dim, hidden_size # xtuner: num_experts * 2 * expert_dim, hidden_size # num_experts * 2 * expert_dim, hidden_size loaded_tensor = loaded_tensor.flatten(0, 1) - elif "fused_w2.weight" in param_name: + elif "fused_w2.weight" in param_name and "mtp" not in param_name: # hf: num_experts, hidden_size, expert_dim # xtuner: num_experts * hidden_size, expert_dim loaded_tensor = loaded_tensor.flatten(0, 1) @@ -117,6 +185,9 @@ def param_to_safetensor( safetensor: torch.Tensor, hf_param_name: str, ): + if "mtp" in hf_param_name: + return safetensor + assert isinstance(hf_param_name, str) if "gate_up_proj" in hf_param_name: # xtuner: num_experts * 2 * expert_dim, hidden_size diff --git a/xtuner/v1/model/moe/qwen3vl_text.py b/xtuner/v1/model/moe/qwen3vl_text.py index 1996edcf75..e0fac316e3 100644 --- a/xtuner/v1/model/moe/qwen3vl_text.py +++ b/xtuner/v1/model/moe/qwen3vl_text.py @@ -4,10 +4,9 @@ import torch from xtuner.v1.data_proto import SequenceContext -from xtuner.v1.loss import CELossContext from xtuner.v1.utils.activation_offload import async_save_on_cpu -from .moe import MoEModelOutputs +from .moe import MoELossContextDict, MoEModelOutputs from .qwen3 import Qwen3MoE, Qwen3MoE30BA3Config, Qwen3MoE235BA22Config @@ -112,9 +111,12 @@ def _deepstack_process( def _forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: CELossContext | None, + loss_ctx: MoELossContextDict | None, return_router_logits: bool = False, ) -> MoEModelOutputs: + if seq_ctx.deepstack_visual_embeds is None: + return super()._forward(seq_ctx, loss_ctx, return_router_logits) + input_ids = seq_ctx.input_ids position_ids = seq_ctx.position_ids @@ -183,7 +185,9 @@ def _forward( hidden_states = self.norm(hidden_states) - loss, (logits, 
extra_info) = self.lm_head(hidden_states, loss_ctx) # type: ignore + # Get LM loss context from dict + lm_loss_ctx = loss_ctx["lm"] if loss_ctx is not None else None + loss, (logits, extra_info) = self.lm_head(hidden_states, lm_loss_ctx) # type: ignore output["loss"] = loss output["logits"] = logits output["extra_info"] = extra_info @@ -193,17 +197,23 @@ def _forward( router_logits = self._select_non_pad_router_logits(router_logits_list, seq_ctx.mask) router_weights = self._select_non_pad_router_logits(router_weights_list, seq_ctx.mask) - if self.balancing_loss: - balancing_loss = self.balancing_loss( - router_weights=router_weights, - n_routed_experts=self.config.n_routed_experts, - num_experts_per_tok=self.config.num_experts_per_tok, - ) - output["balancing_loss"] = balancing_loss - - if self.z_loss: - z_loss = self.z_loss(router_logits=router_logits) - output["z_loss"] = z_loss + # Calculate balancing loss using loss context + if loss_ctx is not None: + balancing_ctx = loss_ctx.get("balancing") + if balancing_ctx is not None: + balancing_loss = balancing_ctx( + router_weights, + self.config.n_routed_experts, + self.config.num_experts_per_tok, + ) + output["balancing_loss"] = balancing_loss + + # Calculate z-loss using loss context + if loss_ctx is not None: + z_loss_ctx = loss_ctx.get("z_loss") + if z_loss_ctx is not None: + z_loss = z_loss_ctx(router_logits) + output["z_loss"] = z_loss tokens_per_expert_global = self._cal_tokens_per_expert(router_logits) output["tokens_per_expert_global"] = tokens_per_expert_global diff --git a/xtuner/v1/module/lm_head/lm_head.py b/xtuner/v1/module/lm_head/lm_head.py index 92cd96c2a4..ea2ed5577e 100644 --- a/xtuner/v1/module/lm_head/lm_head.py +++ b/xtuner/v1/module/lm_head/lm_head.py @@ -6,7 +6,7 @@ from torch.distributed.tensor import DTensor from typing_extensions import overload -from xtuner.v1.loss import CELossContext +from xtuner.v1.loss import LMHeadLossContext Loss: TypeAlias = torch.Tensor @@ -25,11 +25,12 @@ def 
forward( @overload # type: ignore[override] def forward( - self, hidden_states: HiddenStates, loss_ctx: CELossContext + self, hidden_states: HiddenStates, loss_ctx: LMHeadLossContext ) -> tuple[Loss, tuple[Logits | None, dict[str, Any]]]: ... def forward( # type: ignore[override] - self, hidden_states: torch.Tensor, loss_ctx: CELossContext | None = None + self, hidden_states: torch.Tensor, loss_ctx: LMHeadLossContext | None = None, + mtp_config = None, layer_idx: int = 0, ) -> tuple[Loss | None, tuple[Logits | None, dict[str, Any]]]: """Forward pass of the language model head.""" if isinstance(self.weight, DTensor): @@ -46,7 +47,7 @@ def forward( # type: ignore[override] logits = F.linear(hidden_states, w, b) return None, (logits.float(), {}) else: - return loss_ctx.forward(hidden_states, w, b) + return loss_ctx.forward(hidden_states, w, b, mtp_config, layer_idx) @overload # type: ignore def __call__( @@ -55,7 +56,7 @@ def __call__( @overload # type: ignore def __call__( - self, hidden_states: HiddenStates, loss_ctx: CELossContext + self, hidden_states: HiddenStates, loss_ctx: LMHeadLossContext ) -> tuple[Loss, tuple[Logits | None, dict[str, Any]]]: ... 
class MTPConfig(BaseModel):
    """Configuration for Multi-Token Prediction (MTP).

    MTP extends the prediction scope to multiple future tokens at each position,
    creating denser training signals and potentially improving data efficiency.

    This config only contains training-related hyperparameters. The actual
    construction of MTP layers (including choosing Dense vs MoE decoder layers)
    is handled by the model (Dense/MoE) which knows how to create the appropriate
    decoder layers.

    Args:
        name (str): Identifier of this MTP block. Required.
        num_layers (int): Number of MTP layers (prediction depths). Each layer
            predicts tokens at increasing future positions (i+1, i+2, ..., i+D).
        loss_scaling_factor (float): Scaling factor for MTP loss. The total MTP loss
            is computed as the average of losses across all depths, multiplied by
            this factor. Default: 0.1.
        mask_type (str | None): Masking variant for this block. Default: None.
            NOTE(review): accepted values are not defined in this module.
        open_token_list (Sequence[int]): Open tokens for mask. Default: [].
        close_token_list (Sequence[int]): End tokens for mask. Default: [].

    Example:
        >>> mtp_config = MTPConfig(
        ...     name="mtp",
        ...     num_layers=2,
        ...     loss_scaling_factor=0.1,
        ... )
    """

    model_config = ConfigDict(extra="forbid")

    name: Annotated[str, Parameter(group="model")]
    num_layers: Annotated[int, Parameter(group="model")]
    loss_scaling_factor: Annotated[float, Parameter(group="model")] = 0.1

    # Explicit default: without it pydantic treats a `str | None` field as required.
    mask_type: Annotated[str | None, Parameter(group="model")] = None
    # mask_type v2
    open_token_list: Annotated[Sequence[int], Parameter(help="Open tokens for mask", group="model")] = []
    close_token_list: Annotated[Sequence[int], Parameter(help="End tokens for mask", group="model")] = []
mtp_layers.append(mtp_layer) + >>> + >>> # Create MTP block + >>> mtp_block = MTPBlock(mtp_layers=mtp_layers) + >>> + >>> # Forward pass + >>> outputs = mtp_block( + ... hidden_states=h, + ... input_ids=ids, + ... position_ids=pos, + ... embed_tokens_fn=embed_fn, + ... position_embeddings=pos_emb, + ... seq_ctx=ctx, + ... ) + >>> # outputs[0]: predictions for i+1 + >>> # outputs[1]: predictions for i+2 + """ + + def __init__(self, *, mtp_layers: list[MTPLayer]): + super().__init__() + if not mtp_layers: + raise ValueError("mtp_layers cannot be empty") + + self.layers = nn.ModuleList(mtp_layers) + self.num_layers = len(mtp_layers) + + def forward( + self, + hidden_states: torch.Tensor, + embed_tokens_fn: Callable[[torch.Tensor], torch.Tensor], + position_embeddings: tuple[torch.Tensor, torch.Tensor], + seq_ctx: SequenceContext, + ) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """Forward pass through all MTP layers. + + Args: + hidden_states (torch.Tensor): Hidden states from the main model, + shape [batch, seq_len, hidden_size]. + embed_tokens_fn (Callable): Function to embed tokens. Takes token IDs + and returns embeddings. Should have signature: + embed_tokens_fn(token_ids: Tensor) -> Tensor + position_embeddings (tuple[torch.Tensor, torch.Tensor]): Rotary position + embeddings (cos, sin). + seq_ctx (SequenceContext): Sequence context containing input_ids, position_ids, + attention mask, etc. + + Returns: + list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: List of 3-tuples + (hidden_states, router_weights, router_results) for each MTP depth. + Length equals num_layers. 
+ - outputs[0]: Outputs for predicting token at position (i+1) + - outputs[k]: Outputs for predicting token at position (i+k+1) + """ + mtp_outputs = [] + current_hidden_states = hidden_states + current_seq_ctx = seq_ctx + + for layer in self.layers: + # Roll sequence context to get future tokens + # This shifts each packed sequence independently, respecting boundaries + current_seq_ctx = roll_sequence_context(current_seq_ctx, shifts=-1) + + # Get embeddings for future tokens + if current_seq_ctx.inputs_embeds is None: + future_embeddings = embed_tokens_fn(current_seq_ctx.input_ids) # type: ignore[arg-type] + else: + future_embeddings = current_seq_ctx.inputs_embeds + + # Forward through MTP layer + output = layer( + hidden_states=current_hidden_states, + future_embeddings=future_embeddings, + position_embeddings=position_embeddings, + seq_ctx=current_seq_ctx, + ) + if isinstance(output, tuple): + current_hidden_states, router_results, router_weights = output + else: + current_hidden_states = output + # Save output for this depth + mtp_outputs.append(output) + + return mtp_outputs diff --git a/xtuner/v1/module/mtp/mtp_layer.py b/xtuner/v1/module/mtp/mtp_layer.py new file mode 100644 index 0000000000..8099b76b47 --- /dev/null +++ b/xtuner/v1/module/mtp/mtp_layer.py @@ -0,0 +1,126 @@ +"""Multi-Token Prediction (MTP) Layer implementation.""" + +from typing import Literal + +import torch +import torch.nn as nn + +from xtuner.v1.data_proto import SequenceContext +from xtuner.v1.module import RMSNorm +from xtuner.v1.module.linear import build_linear + + +class MTPLayer(nn.Module): + """Single Multi-Token Prediction (MTP) layer. + + MTP Layer wraps a standard decoder layer with MTP-specific preprocessing + and postprocessing. The structure is: + + [enorm + hnorm + projection] → [DecoderLayer] → [final_layernorm] + + The k-th MTP layer predicts the (i+k)-th token by combining: + 1. Hidden states from the previous MTP layer (or main model) + 2. 
Embedding of the future token at position (i+k) + + Note: The decoder layer's internal normalization (input_layernorm) is preserved + for simplicity and modularity. While this adds a small computational overhead, + it allows MTP to work with any decoder layer implementation (Dense, MoE, etc.) + without modification. + + Args: + hidden_size (int): Hidden dimension size. + rms_norm_eps (float): Epsilon for RMSNorm. + rms_norm_type (str): Type of RMSNorm ("default" or "zero_centered"). + decoder_layer (nn.Module): A fully constructed decoder layer instance. + This can be DenseDecoderLayer, MoEDecoderLayer, or any custom decoder layer + that implements the standard forward signature. + float8_cfg: Float8 configuration for the projection layer. + + Example: + >>> from xtuner.v1.module.decoder_layer import DenseDecoderLayer + >>> decoder_layer = DenseDecoderLayer( + ... hidden_size=512, + ... intermediate_size=2048, + ... ... + ... ) + >>> mtp_layer = MTPLayer( + ... hidden_size=512, + ... rms_norm_eps=1e-6, + ... rms_norm_type="default", + ... decoder_layer=decoder_layer, + ... 
) + """ + + def __init__( + self, + *, + hidden_size: int, + rms_norm_eps: float, + rms_norm_type: Literal["default", "zero_centered"], + decoder_layer: nn.Module, + float8_cfg=None, + ): + super().__init__() + self.hidden_size = hidden_size + + # MTP-specific preprocessing components + self.enorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) + self.hnorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) + self.eh_proj = build_linear( + hidden_size * 2, + hidden_size, + bias=False, + float8_cfg=float8_cfg, + ) + + # Core decoder layer (Dense, MoE, or any custom implementation) + self.decoder_layer = decoder_layer + + # MTP-specific postprocessing component + self.final_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) + + def forward( + self, + hidden_states: torch.Tensor, + future_embeddings: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + seq_ctx: SequenceContext, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass through the MTP layer. + + Args: + hidden_states (torch.Tensor): Hidden states from previous layer, + shape [batch, seq_len, hidden_size]. + future_embeddings (torch.Tensor): Embeddings of future tokens, + shape [batch, seq_len, hidden_size]. + position_embeddings (tuple[torch.Tensor, torch.Tensor]): Rotary position + embeddings (cos, sin). + seq_ctx (SequenceContext): Sequence context containing attention mask, etc. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A 3-tuple of + (hidden_states, router_weights, router_results) where each tensor + has shape [batch, seq_len, ...]. 
+ """ + # Step 1: Normalize embeddings and hidden states separately + # This ensures both inputs are in the same numerical range + normalized_embedding = self.enorm(future_embeddings) + normalized_hidden = self.hnorm(hidden_states) + + # Step 2: Concatenate and project to combine information + # [B, S, H] + [B, S, H] → [B, S, 2H] → [B, S, H] + combined = torch.cat([normalized_embedding, normalized_hidden], dim=-1) + projected = self.eh_proj(combined) + + # Step 3: Pass through the standard decoder layer + # This includes attention, MLP, and their respective normalizations + # TODO: TMP hardcode here. + hidden_states, router_results, router_weights = self.decoder_layer( + projected, + position_embeddings=position_embeddings, + seq_ctx=seq_ctx, + ) + + # Step 4: Final normalization before output + hidden_states = self.final_layernorm(hidden_states) + return hidden_states, router_results, router_weights diff --git a/xtuner/v1/module/mtp/utils.py b/xtuner/v1/module/mtp/utils.py new file mode 100644 index 0000000000..15425449dc --- /dev/null +++ b/xtuner/v1/module/mtp/utils.py @@ -0,0 +1,131 @@ +"""Utility functions for Multi-Token Prediction (MTP).""" + +import torch + +from xtuner.v1.data_proto import SequenceContext + + +def roll_packed_tensor( + tensor: torch.Tensor, + cu_seq_lens: torch.IntTensor, + shifts: int = -1, + dim: int = -1, + fill_value: float | int = 0, +) -> torch.Tensor: + """Roll a packed tensor along the specified dimension. + + This function respects sequence boundaries in packed sequences, shifting each + sequence independently without crossing boundaries. + + Args: + tensor (torch.Tensor): Input packed tensor to roll. + cu_seq_lens (torch.IntTensor): Cumulative sequence lengths defining packed + sequence boundaries. Shape [num_sequences + 1]. + shifts (int): Number of positions to shift. Use -1 for left shift (default). + Only negative shifts are supported. + dim (int): Dimension along which to roll. 
The ``cu_seq_lens`` boundaries + are applied on this dimension. Default is -1 (last dimension). + fill_value (float | int): Value used to fill boundary positions after rolling. + Defaults to 0. Use the loss ignore index (e.g., -100) when rolling label + tensors to ensure boundary positions are excluded from loss computation. + + Returns: + torch.Tensor: Rolled tensor with boundary positions filled with ``fill_value``. + + Example: + For packed sequences [1,2,3] and [4,5,6] with shifts=-1, dim=-1: + >>> tensor = torch.tensor([[1, 2, 3, 4, 5, 6]]) + >>> cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32) + >>> rolled = roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-1) + >>> rolled # [[2, 3, 0, 5, 6, 0]] + + For a 3D tensor with dim=-2 (e.g., inputs_embeds of shape [1, seq_len, hidden]): + >>> tensor = torch.arange(12).reshape(1, 6, 2) + >>> cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32) + >>> rolled = roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-2) + >>> rolled[0, 2] # tensor([0, 0]) (boundary filled with fill_value=0) + """ + assert shifts <= 0, "Only negative shift is supported" + + # Normalize dim to a positive index + dim = dim % tensor.dim() + + rolled_tensor = tensor.clone() + + # Roll each packed sequence independently within its boundaries + for i in range(len(cu_seq_lens) - 1): + start_idx = cu_seq_lens[i].item() + end_idx = cu_seq_lens[i + 1].item() + + # Extract sequence slice along the specified dimension + seq_slice = tensor.narrow(dim, start_idx, end_idx - start_idx) # type: ignore[arg-type] + rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dim) + + # Fill the last |shifts| positions along dim to avoid information + # leakage across sequences. For shifts=-1 the last 1 position is + # zeroed; for shifts=-2 the last 2 positions are zeroed, etc. 
+ zero_len = -shifts + zero_len = min(zero_len, (end_idx - start_idx)) + zero_start = (end_idx - start_idx) - zero_len + zero_slice = rolled_seq.narrow(dim, zero_start, zero_len) # type: ignore[arg-type] + zero_slice.zero_() + + # Write back to the rolled tensor + rolled_tensor.narrow(dim, start_idx, end_idx - start_idx).copy_(rolled_seq) # type: ignore[arg-type] + + return rolled_tensor + + +def roll_sequence_context( + seq_ctx: SequenceContext, + shifts: int = -1, +) -> SequenceContext: + """Roll the sequence context to get future tokens for MTP prediction. + + This function respects sequence boundaries in packed sequences, shifting each + sequence independently without crossing boundaries. Returns a new + ``SequenceContext`` — the original is never modified. + + Args: + seq_ctx (SequenceContext): Input sequence context with packed sequences. + shifts (int): Number of positions to shift. Use -1 for left shift (default). + Only -1 is currently supported. + + Returns: + SequenceContext: A new sequence context with shifted input_ids (and/or + inputs_embeds). Positions at sequence boundaries are zeroed to prevent + information leakage. 
+ + Example: + For packed sequences [1,2,3] and [4,5,6] with shifts=-1: + Original input_ids: [1, 2, 3, 4, 5, 6] + Rolled input_ids: [2, 3, 0, 5, 6, 0] + """ + sp_mesh = seq_ctx.sequence_parallel_mesh + is_sp = sp_mesh is not None and sp_mesh.size() > 1 + + overrides: dict = {} + + raw_input_ids = seq_ctx.raw_input_ids + if raw_input_ids is not None: + rolled = roll_packed_tensor(tensor=raw_input_ids, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-1) + overrides["raw_input_ids"] = rolled + if is_sp: + s = seq_ctx._shard_start + overrides["input_ids"] = rolled[:, s : s + seq_ctx._shard_size] + else: + overrides["input_ids"] = rolled + + raw_inputs_embeds = seq_ctx.raw_inputs_embeds + if raw_inputs_embeds is not None: + rolled_e = roll_packed_tensor( + tensor=raw_inputs_embeds, cu_seq_lens=seq_ctx.cu_seq_lens_q, shifts=shifts, dim=-2 + ) + overrides["raw_inputs_embeds"] = rolled_e + if is_sp: + s = seq_ctx._shard_start + overrides["inputs_embeds"] = rolled_e[:, s : s + seq_ctx._shard_size] + else: + overrides["inputs_embeds"] = rolled_e + + return seq_ctx.copy(**overrides) diff --git a/xtuner/v1/rl/base/loss.py b/xtuner/v1/rl/base/loss.py index 8164538b9d..006f30ca75 100644 --- a/xtuner/v1/rl/base/loss.py +++ b/xtuner/v1/rl/base/loss.py @@ -4,8 +4,7 @@ from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Self -from xtuner.v1.loss import BaseLossConfig, BaseLossKwargs -from xtuner.v1.loss.base_loss_ctx import BaseLossContext +from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext, CELossKwargs from xtuner.v1.loss.utils import sp_gather, sp_split from xtuner.v1.utils.device import get_device @@ -24,7 +23,7 @@ def compute_kl_loss_weight( return kl_loss_weight -class BaseRLLossConfig(BaseLossConfig): +class BaseRLLossConfig(CELossConfig): """Base configuration for reinforcement learning loss functions in XTuner RL. 
@@ -96,14 +95,36 @@ def _loss_kwargs_cls(self) -> type["BaseRLLossKwargs"]: def build( self, - sp_mesh: DeviceMesh | None, - shifted_labels: torch.Tensor, - advantages: torch.Tensor, - rollout_logprobs: torch.Tensor | None = None, - old_logprobs: torch.Tensor | None = None, - rollout_is_weights: torch.Tensor | None = None, - ref_logprobs: torch.Tensor | None = None, - ) -> "BaseRLLossContext": + data: dict, + sp_mesh: DeviceMesh | None = None, + ) -> "BaseRLLossContext | None": + """Build RL loss context from data dict. + + Args: + data (dict): Data dictionary containing RL-specific fields: + - shifted_labels (torch.Tensor): The shifted labels + - advantages (torch.Tensor): Advantage estimates + - rollout_logprobs (torch.Tensor | None): Rollout log probabilities + - old_logprobs (torch.Tensor | None): Old policy log probabilities (optional, can be set later) + - rollout_is_weights (torch.Tensor | None): Importance sampling weights + - ref_logprobs (torch.Tensor | None): Reference model log probabilities + sp_mesh (DeviceMesh | None): Sequence parallel device mesh + + Returns: + BaseRLLossContext | None: The built loss context, or None if required fields are missing + """ + # Check for required fields + if "shifted_labels" not in data or "advantages" not in data: + return None + + # Extract RL-specific fields from data + shifted_labels = data["shifted_labels"] + advantages = data["advantages"] + rollout_logprobs = data.get("rollout_logprobs", None) + old_logprobs = data.get("old_logprobs", None) + rollout_is_weights = data.get("rollout_is_weights", None) + ref_logprobs = data.get("ref_logprobs", None) + LossKwargs = self._loss_kwargs_cls loss_kwargs = LossKwargs( shifted_labels=shifted_labels, @@ -120,7 +141,7 @@ def build( return LossContext(self, loss_kwargs) -class BaseRLLossKwargs(BaseLossKwargs): +class BaseRLLossKwargs(CELossKwargs): """Keyword arguments for reinforcement learning loss computation. 
Args: @@ -144,7 +165,9 @@ class BaseRLLossKwargs(BaseLossKwargs): is_weights: torch.Tensor | None = None def sp_split(self, sp_mesh: DeviceMesh) -> Self: - self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100) + # Call parent class to handle shifted_labels + super().sp_split(sp_mesh) + # Handle RL-specific fields self.advantages = sp_split(self.advantages, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0) if self.rollout_logprobs is not None: self.rollout_logprobs = sp_split(self.rollout_logprobs, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0) @@ -158,7 +181,9 @@ def sp_split(self, sp_mesh: DeviceMesh) -> Self: return self def to(self, device: torch.device | str) -> Self: - self.shifted_labels = self.shifted_labels.to(device) + # Call parent class to handle shifted_labels + super().to(device) + # Handle RL-specific fields self.advantages = self.advantages.to(device) if self.old_logprobs is not None: self.old_logprobs = self.old_logprobs.to(device) @@ -177,9 +202,9 @@ def to(self, device: torch.device | str) -> Self: return self -class BaseRLLossContext(BaseLossContext): - loss_cfg: BaseRLLossConfig - loss_kwargs: BaseRLLossKwargs +class BaseRLLossContext(CELossContext): + loss_cfg: BaseRLLossConfig # type: ignore[assignment] + loss_kwargs: BaseRLLossKwargs # type: ignore[assignment] def compute_rollout_is( self, sp_mesh: DeviceMesh, num_tokens: torch.Tensor diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 08a96620c3..27bf94a178 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -377,7 +377,8 @@ def compute_actor_logprobs( self._engine._maybe_precompute_float8_dynamic_scale_for_fsdp() old_logprobs_list: list[torch.Tensor] = [] for seq_ctx, shifted_labels in zip(seq_ctx_list, shifted_labels_list): - loss_ctx = self.logprob_cfg.build(shifted_labels=shifted_labels) + loss_ctx = self.logprob_cfg.build(data={"shifted_labels": shifted_labels}) + assert loss_ctx is not 
None output = self._engine.forward_only(seq_ctx=seq_ctx, loss_ctx=loss_ctx) old_logprobs_list.append(output["loss"]) return old_logprobs_list @@ -501,10 +502,16 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo rollout_logprobs = data.get("rollout_logprobs", None) rollout_logprobs = rollout_logprobs.to(DEVICE) if rollout_logprobs is not None else None loss_ctx = loss_cfg.build( - self.sp_mesh, shifted_labels=shifted_labels, advantages=advantages, rollout_logprobs=rollout_logprobs + data={ + "shifted_labels": shifted_labels, + "advantages": advantages, + "rollout_logprobs": rollout_logprobs, + }, + sp_mesh=self.sp_mesh, ) seq_ctx_list.append(seq_ctx) + assert loss_ctx is not None loss_ctx_list.append(loss_ctx) del data_batches @@ -596,8 +603,9 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo LossContext = loss_cfg.loss_ctx_cls for i in range(0, len(loss_ctx_list), iters_per_step): batches_loss_ctx = loss_ctx_list[i : i + iters_per_step] - batches_loss_ctx = LossContext.build_batches(batches_loss_ctx) - batched_loss_ctx_list.extend(batches_loss_ctx) + batched_loss_ctx_list.extend( + LossContext.build_batches(batches_loss_ctx) # type: ignore[arg-type] + ) # train optimizer steps for i in range(0, len(seq_ctx_list), iters_per_step): @@ -605,7 +613,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo batches_loss_ctx = batched_loss_ctx_list[i : i + iters_per_step] engine_input = [ - ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx) + ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx) # type: ignore[typeddict-item] for seq_ctx, loss_ctx in zip(batches_seq_ctx, batches_loss_ctx) ] @@ -707,7 +715,7 @@ def _train_one_step_sft(self, data_batch): if self.sp_mesh.size() > 1: seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh) seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=self.sp_mesh) + loss_ctx = 
loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=self.sp_mesh) loss_ctx_list.append(loss_ctx) del data_batch diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 38d46b3a09..84f78d714b 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -32,7 +32,7 @@ from xtuner.v1.datasets.config import BaseDataloaderConfig, DataloaderConfig, DatasetConfigList from xtuner.v1.engine import TrainEngine from xtuner.v1.engine.train_engine import TrainStepInfo -from xtuner.v1.loss import CELossConfig, CELossContext +from xtuner.v1.loss import CELossConfig from xtuner.v1.model.base import ModelItem, XTunerBaseModelConfig from xtuner.v1.model.moe.moe import MoEConfig from xtuner.v1.patch import patch_default_save_plan @@ -579,8 +579,19 @@ def __init__( global_batch_size = self.data_mesh["dp"].size() self._global_batch_size = global_batch_size + self._resolve_model_loss_cfg(model_cfg, loss_cfg) + + if loss_cfg is None: + loss_cfg = CELossConfig() + self._resolve_config_conflicts(self.tokenizer, model_cfg, dataloader_cfg, fsdp_cfg) + if intra_layer_micro_batch > 1 and isinstance(model_cfg, MoEConfig) and model_cfg.mtp_config is not None: + raise ValueError( + "MTP (Multi-Token Prediction) is not supported with intra_layer_micro_batch > 1. " + f"Got intra_layer_micro_batch={intra_layer_micro_batch} and mtp_config={model_cfg.mtp_config}." 
+ ) + if dataset_cfg is not None: # TODO: Removed in version 1.1.0 logger.warning("`dataset_cfg` is deprecated, please use `dataloader_cfg.dataset_config_list` instead") # For backward compatibility, reserve the dataset_cfg interface, remove it later @@ -618,8 +629,6 @@ def __init__( self._lr_cfg = lr_cfg self._lr_scheduler = self.build_lr_scheduler(lr_cfg, self.total_step) - if loss_cfg is None: - loss_cfg = CELossConfig() self.loss_cfg = loss_cfg if debug: @@ -784,28 +793,26 @@ def fit(self): self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds") def _prepare_model_input(self, data_batch) -> list[ModelItem]: - loss_cfg: CELossConfig = self.loss_cfg seq_ctx_list: list[SequenceContext] = [] - loss_ctx_list: list[CELossContext] = [] + # 1. Extract seq_ctx for data in data_batch: - seq_ctx = data.pop("seq_ctx").to(DEVICE) + seq_ctx = data["seq_ctx"].to(DEVICE) if self.sp_mesh.size() > 1: seq_ctx = seq_ctx.split(sequence_parallel_mesh=self.sp_mesh) seq_ctx_list.append(seq_ctx) - loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=self.sp_mesh) - loss_ctx_list.append(loss_ctx) + + # 2. Compute cu_seq_lens_list (for calibration) + # 3. Call model's interface to build and calibrate all loss_ctx (done in one shot) + loss_ctx_dict_list = self._engine.model.build_loss_ctx_batch(data_batch, sp_mesh=self.sp_mesh) # TODO: Consider moving data_batch deletion to the caller for better memory management. del data_batch - cu_seq_lens_list = [seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list] - loss_ctx_list = CELossContext.build_batches( - loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=self.sp_mesh - ) - + # 4. 
def _resolve_model_loss_cfg(self, model_cfg: XTunerBaseModelConfig, loss_cfg: CELossConfig | None):
    """Backward compatibility: push the Trainer's ``loss_cfg`` into the model config.

    When ``loss_cfg`` is provided it overwrites ``lm_loss_cfg`` on
    ``model_cfg.text_config`` if the model config has a ``text_config``
    attribute, and on ``model_cfg`` itself otherwise. Rank 0 then logs a
    warning asking users to configure the loss on the model config directly.
    When ``loss_cfg`` is ``None`` the model config is left untouched.

    Args:
        model_cfg (XTunerBaseModelConfig): Model configuration, modified in place.
        loss_cfg (CELossConfig | None): Loss configuration passed to the Trainer.
    """
    if loss_cfg is None:
        return

    # Configs that expose a ``text_config`` keep the LM loss settings there.
    target_cfg = model_cfg.text_config if hasattr(model_cfg, "text_config") else model_cfg
    target_cfg.lm_loss_cfg = loss_cfg

    if self.rank == 0:
        logger.warning(
            "Setting model_cfg.lm_loss_cfg from Trainer's loss_cfg for backward compatibility. "
            "In the future, please set lm_loss_cfg directly in model_cfg instead of Trainer."
        )