Skip to content
4 changes: 2 additions & 2 deletions tests/engine/test_dense_train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ def warmup_fn(x):
seq_ctx = seq_ctx.split(sequence_parallel_mesh=sp_mesh)
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=sp_mesh)
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)

seq_ctx = seq_ctx_list[0]
loss_ctx = loss_ctx_list[0]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
Expand Down
8 changes: 4 additions & 4 deletions tests/engine/test_moe_train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
Expand Down Expand Up @@ -184,12 +184,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
Expand Down
12 changes: 6 additions & 6 deletions tests/engine/test_moe_train_engine_float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
Expand Down Expand Up @@ -165,12 +165,12 @@ def warmup_fn(x):
seq_ctx.num_padding = pack_len
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
loss_log = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
Expand Down Expand Up @@ -264,12 +264,12 @@ def warmup_fn(x):
seq_ctx.to('cuda')
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx=loss_ctx)]
engine_input = [ModelItem(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})]
logs_info = engine.train_step(engine_input)["logs_info"]
grad_norm = engine.clip_grad_norm()
engine.step_optimizer(grad_norm)
Expand Down
8 changes: 4 additions & 4 deletions tests/loss/test_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_global_loss_reduction(self, loss_mode, grad_accumulation_steps, chunk_s
for data in data_batch:
seq_ctx = data["seq_ctx"]
seq_ctx_list.append(seq_ctx)
loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None)
loss_ctx_list.append(loss_ctx)
loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list])

Expand Down Expand Up @@ -172,7 +172,7 @@ def test_other_loss_reduction(self, loss_reduction, loss_mode, grad_accumulation
for data in data_batch:
seq_ctx = data["seq_ctx"]
seq_ctx_list.append(seq_ctx)
loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"], sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": data["shifted_labels"]}, sp_mesh=None)
loss_ctx_list.append(loss_ctx)
loss_ctx_list = CELossContext.build_batches(loss_ctx_list, cu_seq_lens_list=[seq_ctx.cu_seq_lens_q for seq_ctx in seq_ctx_list])

Expand Down Expand Up @@ -310,7 +310,7 @@ def test_sp_global_loss_reduction(self, loss_mode, sp_size, grad_accumulation_st
sp_mesh = data_mesh['sp']
seq_ctx.sequence_parallel_mesh = sp_mesh
seq_ctx_list = [seq_ctx]
loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh)
loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
if sp_size > 1:
seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh)
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_sp_others_loss_reduction(self, loss_reduction, loss_mode, sp_size, grad
sp_mesh = data_mesh['sp']
seq_ctx.sequence_parallel_mesh = sp_mesh
seq_ctx_list = [seq_ctx]
loss_ctx = loss_cfg.build(shifted_labels=target, sp_mesh=sp_mesh)
loss_ctx = loss_cfg.build(data={"shifted_labels": target}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
if sp_size > 1:
seq_ctx_list[0] = seq_ctx_list[0].split(sequence_parallel_mesh=sp_mesh)
Expand Down
2 changes: 1 addition & 1 deletion tests/loss/test_grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size,
if sp_size > 1:
seq_ctx = seq_ctx.split(sp_mesh)
seq_ctx_list.append(seq_ctx)
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels_list_rank[iter_idx], advantages=advantages_list_rank[iter_idx], sp_mesh=sp_mesh)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]}, sp_mesh=sp_mesh)
loss_ctx_list.append(loss_ctx)

with torch.no_grad():
Expand Down
3 changes: 1 addition & 2 deletions tests/loss/test_oreal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ def test_grpo_loss(self, grad_acc, sp_size, kl_loss_coef, loss_mode, chunk_size,
seq_ctx = seq_ctx.split(sp_mesh)
seq_ctx_list.append(seq_ctx)
loss_ctx = loss_cfg.build(
shifted_labels=shifted_labels_list_rank[iter_idx],
advantages=advantages_list_rank[iter_idx],
data={"shifted_labels": shifted_labels_list_rank[iter_idx], "advantages": advantages_list_rank[iter_idx]},
sp_mesh=sp_mesh,
)
loss_ctx_list.append(loss_ctx)
Expand Down
8 changes: 4 additions & 4 deletions tests/model/test_gpt_oss_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
loss_cfg = CELossConfig()
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
Expand All @@ -87,7 +87,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
with torch.no_grad():
output = gpt_oss_model(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
loss_cfg = CELossConfig()
seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
Expand All @@ -152,7 +152,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
with torch.no_grad():
output = gpt_oss_model(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=1e-2, rtol=1e-2))
Expand Down
16 changes: 8 additions & 8 deletions tests/model/test_intern_s1.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_interns1_text_run(self, device, tol):

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
Expand All @@ -87,7 +87,7 @@ def test_interns1_text_run(self, device, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_interns1_image_run(self, device, sp_size, tol):

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
Expand All @@ -195,7 +195,7 @@ def test_interns1_image_run(self, device, sp_size, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
Expand Down Expand Up @@ -256,7 +256,7 @@ def test_fsdp_text_accuracy(self, device, tol):
seq_ctx_list = [seq_ctx]
loss_cfg = CELossConfig()
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
Expand All @@ -265,7 +265,7 @@ def test_fsdp_text_accuracy(self, device, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
Expand Down Expand Up @@ -370,7 +370,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
seq_ctx_list = [seq_ctx]
loss_cfg = CELossConfig()
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=sp_mesh)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=sp_mesh)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
Expand All @@ -379,7 +379,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
with torch.no_grad():
output = interns1_model(
seq_ctx=seq_ctx,
loss_ctx=loss_ctx,
loss_ctx={"lm": loss_ctx},
)
loss = output["loss"]
self.assertTrue(torch.allclose(loss, expected_loss.to(loss.dtype), atol=tol, rtol=tol))
Expand Down
10 changes: 5 additions & 5 deletions tests/model/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def test_moe_config(self, dtype, device):

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]
model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})


class TestDistributedMoE(DeterministicDDPTestCase):
Expand Down Expand Up @@ -135,15 +135,15 @@ def test_parallel_accuracy(self, dtype, device, dispatcher, n_shared_experts, fi

seq_ctx_list = [seq_ctx]
LossContext = loss_cfg.loss_ctx_cls
loss_ctx = loss_cfg.build(shifted_labels=shifted_labels, sp_mesh=None)
loss_ctx = loss_cfg.build(data={"shifted_labels": shifted_labels}, sp_mesh=None)
loss_ctx_list = [loss_ctx]
loss_ctx_list = LossContext.build_batches(loss_ctx_list)
loss_ctx = loss_ctx_list[0]
seq_ctx = seq_ctx_list[0]

loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"]
loss_parallel = parallel_model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"]

loss_expected = model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)["loss"]
loss_expected = model(seq_ctx=seq_ctx, loss_ctx={"lm": loss_ctx})["loss"]

torch.allclose(loss_expected, loss_parallel, atol=1e-6, rtol=1e-4)

Expand Down
Loading