From f95a00340444cfe773a74d3b18dadf3a369ae9a3 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Thu, 15 Jan 2026 14:50:34 -0500 Subject: [PATCH 1/6] Expert balance --- src/MaxText/configs/base.yml | 1 + src/MaxText/configs/types.py | 1 + src/MaxText/layers/moe.py | 22 ++++++++++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 2b4c338e78..1806434e76 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -178,6 +178,7 @@ megablox: true sparse_matmul: true capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.0 # weight for the load balance loss +expert_balance: False # whether or not to do expert balancing use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index e66d75cdca..3eec20cb56 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -553,6 +553,7 @@ class MoEGeneral(BaseModel): num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.") capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.") load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.") + expert_balance: bool = Field(False, description="Whether to use expert balancing.") use_custom_sort_vjp: bool = Field( True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul." ) diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index e5cf2a4d06..98220b1ee3 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -1618,6 +1618,28 @@ def dense_matmul( wo_bias, ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]: """Dense matrix multiplication.""" + if self.config.expert_balance: + ###################################################################################################### + ############################## start hard code for uniform expert #################################### + # Create deterministic rotational pattern for gate logits + batch_size, seq_len, num_experts = gate_logits.shape + + # Create base weights for experts (increasing values) + base_weights = jnp.linspace(0.1, 0.1 * num_experts, num_experts, dtype=gate_logits.dtype) + + # Create position-based indices matrix [seq_len, num_experts] + # Each row represents which index in base_weights to use after rotation + indices = (jnp.arange(num_experts)[None, :] + jnp.arange(seq_len)[:, None]) % num_experts + + # Use advanced indexing to create the rotated weights matrix in one operation + # This takes the appropriate weight for each position based on the rotation pattern + rotated_weights = base_weights[indices] + + # Broadcast to batch dimension + gate_logits = jnp.broadcast_to(rotated_weights[None, :, :], (batch_size, seq_len, num_experts)) + ############################################# end #################################################### + ###################################################################################################### + # gate_logits: batch, length, expert gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None)) if self.config.model_name.startswith("deepseek3"): From 677710937cae12f1d8bf4aaf8b1a26c5ccd64729 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Thu, 15 Jan 2026 14:54:27 -0500 Subject: [PATCH 2/6] nanoofp8 quantization --- src/MaxText/layers/moe.py | 2 +- src/MaxText/layers/quantizations.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index 98220b1ee3..edaaa393b6 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -1583,7 +1583,7 @@ def get_einsum( def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument # simply skip kwargs, since aqt einsum doesn't support any kwargs # like precision - is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization) + is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) ) kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype} return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py index d0f9353b6c..c9bf025533 100644 --- a/src/MaxText/layers/quantizations.py +++ b/src/MaxText/layers/quantizations.py @@ -296,6 +296,8 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): """Returns dot_general configured with aqt params.""" return nn.NANOOFp8DotGeneralOp + def einsum(self, dtype: DType = jnp.float32): + return Fp8Einsum(dtype=dtype,e4m3_dtype=jnp.float8_e4m3fnuz,e5m2_dtype=jnp.float8_e5m2fnuz) def _get_int8_quant_config(config): drhs_bits = None From 3085abb883f5b2be8b51e31eee744e2216bf05f0 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Thu, 15 Jan 2026 14:56:29 -0500 Subject: [PATCH 3/6] Device info --- src/MaxText/max_utils.py | 31 +++++++++++++++++++++++++++++++ src/MaxText/train.py | 1 + 2 files changed, 32 insertions(+) diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index 510878f9be..7c19be1184 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -19,6 +19,7 @@ from collections.abc import Sequence import functools from functools import partial +import json import os import socket import subprocess @@ -705,6 +706,36 @@ def print_system_information(): max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") max_logging.log(f"System Information: Jax Backend: {jax.extend.backend.get_backend().platform_version}") + devices = jax.devices() + max_logging.log(f"System Information: Number of devices: {len(devices)}, jax path {jax.__file__}") + for i, device in enumerate(devices): + if device.local_hardware_id is not None: + max_logging.log( + f"System Information: Device {i}: {device.id} " + f"(Local id: {device.local_hardware_id}, Process index: {device.process_index})" + ) + + +def save_device_information(config): + """Convert device information to JSON format.""" + devices = jax.devices() + device_info = {'hostname': socket.gethostname(), 'devices': []} + + for device in devices: + if device.local_hardware_id is not None: + info = { + "id": device.id, + "local_hardware_id": device.local_hardware_id, + "process_index": device.process_index, + "device_kind": device.device_kind, + "platform_version": jax.extend.backend.get_backend().platform_version, + } + device_info['devices'].append(info) + # Save to JSON file + device_info_path = os.path.join(config.base_output_directory, "device_info.json") + with open(device_info_path, "w") as f: + json.dump(device_info, f, indent=4) + def permute_to_match_maxtext_rope(arr): """Permutes the Huggingface Rope to match the MaxText logic.""" diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 3fae3e056a..c323b2ea9a 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -529,6 +529,7 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any] config = pyconfig.initialize(argv) max_utils.print_system_information() validate_train_config(config) + max_utils.save_device_information(config) jax.config.update("jax_use_shardy_partitioner", config.shardy) # update explicit sharding-supported config if config.shard_mode == ShardMode.EXPLICIT: From 79b04ef2fc79c13546cbb83b7b7172e2c75e7880 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Thu, 15 Jan 2026 14:57:52 -0500 Subject: [PATCH 4/6] Sync after step --- src/MaxText/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/MaxText/train.py b/src/MaxText/train.py index c323b2ea9a..e4f3303017 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -453,6 +453,7 @@ def train_loop(config, recorder, state=None): if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) state, metrics = p_train_step(state, example_batch, nextrng) + jax.block_until_ready(state) step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() From c25938311758478a05041b38655367e4889282b6 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Thu, 15 Jan 2026 15:20:56 -0500 Subject: [PATCH 5/6] Causal mask for synthetic dataset --- src/MaxText/layers/attention_op.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index 1961ee1fff..94a0be1a73 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1399,6 +1399,10 @@ def cudnn_flash_attention( attn_mask = None dummy_attn_mask = None mask_type = "causal" + elif self.config.dataset_type == "synthetic": + attn_mask = None + dummy_attn_mask = None + mask_type = "causal" else: # Default case: no packing, no context parallelism dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8) From c39c3428f4a740d495daf4f33e7775c8c0fd35d0 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Thu, 15 Jan 2026 15:23:19 -0500 Subject: [PATCH 6/6] Comment out unsupported context_parallel_strategy in TE --- src/MaxText/layers/attention_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index 94a0be1a73..075093bdc4 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1425,7 +1425,7 @@ def cudnn_flash_attention( window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis="context", - context_parallel_strategy=self.config.context_parallel_strategy, + # context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, )