Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ megablox: true
sparse_matmul: true
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.0 # weight for the load balance loss
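expert_balance: False # whether to override the router gate logits in dense_matmul with a deterministic, position-rotated pattern (hard-coded uniform expert selection)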
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
Expand Down
1 change: 1 addition & 0 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ class MoEGeneral(BaseModel):
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
expert_balance: bool = Field(False, description="Whether to use expert balancing.")
use_custom_sort_vjp: bool = Field(
True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
)
Expand Down
6 changes: 5 additions & 1 deletion src/MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,10 @@ def cudnn_flash_attention(
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
elif self.config.dataset_type == "synthetic":
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
else:
# Default case: no packing, no context parallelism
dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
Expand All @@ -1421,7 +1425,7 @@ def cudnn_flash_attention(
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis="context",
context_parallel_strategy=self.config.context_parallel_strategy,
# context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)

Expand Down
24 changes: 23 additions & 1 deletion src/MaxText/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@ def get_einsum(
def aqt_einsum(*args, **kwargs):  # pylint: disable=unused-argument
  """Applies the quantized einsum to ``args``; ``kwargs`` are intentionally dropped.

  Keyword arguments such as ``precision`` are skipped because the AQT einsum
  does not support any kwargs.
  """
  # Both FP8 quantization classes build their einsum from a dtype; every other
  # quantization is treated as AQT and is given the rhs mesh axes instead.
  is_fp8 = isinstance(self.quant, (quantizations.Fp8Quantization, quantizations.NANOOFp8Quantization))
  kw = {"dtype": self.dtype} if is_fp8 else {"mesh_axes": rhs_mesh_axes}
  return self.quant.einsum(**kw)(*args)  # pytype: disable=attribute-error

Expand Down Expand Up @@ -1618,6 +1618,28 @@ def dense_matmul(
wo_bias,
) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
"""Dense matrix multiplication."""
if self.config.expert_balance:
######################################################################################################
############################## start hard code for uniform expert ####################################
# Create deterministic rotational pattern for gate logits
batch_size, seq_len, num_experts = gate_logits.shape

# Create base weights for experts (increasing values)
base_weights = jnp.linspace(0.1, 0.1 * num_experts, num_experts, dtype=gate_logits.dtype)

# Create position-based indices matrix [seq_len, num_experts]
# Each row represents which index in base_weights to use after rotation
indices = (jnp.arange(num_experts)[None, :] + jnp.arange(seq_len)[:, None]) % num_experts

# Use advanced indexing to create the rotated weights matrix in one operation
# This takes the appropriate weight for each position based on the rotation pattern
rotated_weights = base_weights[indices]

# Broadcast to batch dimension
gate_logits = jnp.broadcast_to(rotated_weights[None, :, :], (batch_size, seq_len, num_experts))
############################################# end ####################################################
######################################################################################################

# gate_logits: batch, length, expert
gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
if self.config.model_name.startswith("deepseek3"):
Expand Down
2 changes: 2 additions & 0 deletions src/MaxText/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns dot_general configured with aqt params."""
return nn.NANOOFp8DotGeneralOp

def einsum(self, dtype: DType = jnp.float32):
  """Returns an Fp8Einsum for `dtype` that uses the fnuz float8 e4m3/e5m2 types."""
  fnuz_dtypes = {
      "e4m3_dtype": jnp.float8_e4m3fnuz,
      "e5m2_dtype": jnp.float8_e5m2fnuz,
  }
  return Fp8Einsum(dtype=dtype, **fnuz_dtypes)

def _get_int8_quant_config(config):
drhs_bits = None
Expand Down
31 changes: 31 additions & 0 deletions src/MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Sequence
import functools
from functools import partial
import json
import os
import socket
import subprocess
Expand Down Expand Up @@ -705,6 +706,36 @@ def print_system_information():
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
max_logging.log(f"System Information: Jax Backend: {jax.extend.backend.get_backend().platform_version}")

devices = jax.devices()
max_logging.log(f"System Information: Number of devices: {len(devices)}, jax path {jax.__file__}")
for i, device in enumerate(devices):
if device.local_hardware_id is not None:
max_logging.log(
f"System Information: Device {i}: {device.id} "
f"(Local id: {device.local_hardware_id}, Process index: {device.process_index})"
)


def save_device_information(config):
"""Convert device information to JSON format."""
devices = jax.devices()
device_info = {'hostname': socket.gethostname(), 'devices': []}

for device in devices:
if device.local_hardware_id is not None:
info = {
"id": device.id,
"local_hardware_id": device.local_hardware_id,
"process_index": device.process_index,
"device_kind": device.device_kind,
"platform_version": jax.extend.backend.get_backend().platform_version,
}
device_info['devices'].append(info)
# Save to JSON file
device_info_path = os.path.join(config.base_output_directory, "device_info.json")
with open(device_info_path, "w") as f:
json.dump(device_info, f, indent=4)


def permute_to_match_maxtext_rope(arr):
"""Permutes the Huggingface Rope to match the MaxText logic."""
Expand Down
2 changes: 2 additions & 0 deletions src/MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def train_loop(config, recorder, state=None):
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
state, metrics = p_train_step(state, example_batch, nextrng)
jax.block_until_ready(state)

step_time_delta = datetime.datetime.now() - last_step_completion
last_step_completion = datetime.datetime.now()
Expand Down Expand Up @@ -529,6 +530,7 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
config = pyconfig.initialize(argv)
max_utils.print_system_information()
validate_train_config(config)
max_utils.save_device_information(config)
jax.config.update("jax_use_shardy_partitioner", config.shardy)
# update explicit sharding-supported config
if config.shard_mode == ShardMode.EXPLICIT:
Expand Down