Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def create_orbax_checkpoint_manager(
orbax_logger: Any = None, # pytype: disable=attribute-error
use_ocdbt: bool = True,
use_zarr3: bool = True,
max_to_keep: int = 5,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
Expand Down Expand Up @@ -213,6 +214,7 @@ def create_orbax_checkpoint_manager(
create=True,
save_interval_steps=save_interval_steps,
enable_async_checkpointing=use_async,
max_to_keep = max_to_keep,
),
logger=orbax_logger,
)
Expand Down
3 changes: 2 additions & 1 deletion src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ megablox: True
sparse_matmul: True
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss
expert_balance: False # whether or not to do expert balancing
use_random_routing: False # whether to use random routing for debug/test purpose
use_custom_sort_vjp: True # whether to use a custom sort vjp for sparse matmul ops
use_ring_of_experts: False # whether to use ring of experts for sparse matmul expert parallelism
Expand Down Expand Up @@ -957,4 +958,4 @@ partial_rotary_factor: 1.0

# Use tokamax library for gmm kernel implementation
use_tokamax_gmm: false
use_tokamax_splash: false
use_tokamax_splash: false
1 change: 1 addition & 0 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ class MoEGeneral(BaseModel):
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
load_balance_loss_weight: NonNegativeFloat = Field(0.01, description="Weight for the load balancing auxiliary loss.")
expert_balance: bool = Field(False, description="Whether to use expert balancing.")
use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.")
use_ring_of_experts: bool = Field(
False, description="Whether to use Ring of Experts for sparse matmul expert parallelism."
Expand Down
12 changes: 8 additions & 4 deletions src/MaxText/input_pipeline/_hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def preprocessing_pipeline(
use_sft=None,
sft_train_on_completion_only=True,
grain_worker_count=1, # only support 0 or 1
max_segments_per_seq = 1, # max segments per sequence
):
"""pipeline for preprocessing HF dataset"""

Expand Down Expand Up @@ -298,10 +299,11 @@ def lists2array(x):
if packing and not use_dpo:
length_struct = {col: max_target_length for col in data_column_names}
operations.append(
grain.experimental.PackAndBatchOperation(
batch_size=global_batch_size // jax.process_count(),
length_struct=length_struct,
)
grain.experimental.PackAndBatchOperation(
batch_size=global_batch_size // jax.process_count(),
length_struct=length_struct,
max_sequences_per_bin=max_segments_per_seq,
)
)
operations.append(_input_pipeline_utils.ReformatPacking(data_column_names))
else:
Expand Down Expand Up @@ -386,6 +388,7 @@ def make_hf_train_iterator(
use_sft=config.use_sft,
sft_train_on_completion_only=config.sft_train_on_completion_only,
chat_template_path=config.chat_template_path,
max_segments_per_seq=config.max_segments_per_seq,
)
return train_iter

Expand Down Expand Up @@ -437,5 +440,6 @@ def make_hf_eval_iterator(
use_sft=config.use_sft,
sft_train_on_completion_only=config.sft_train_on_completion_only,
chat_template_path=config.chat_template_path,
max_segments_per_seq=config.max_segments_per_seq,
)
return eval_iter
4 changes: 2 additions & 2 deletions src/MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,7 @@ def cudnn_flash_attention(
dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None)
max_segments_per_seq = self.config.max_segments_per_seq
elif using_context_parallelism:
elif using_context_parallelism or self.config.dataset_type == "synthetic":
if self.attention_type == AttentionType.LOCAL_SLIDING:
raise AssertionError("Sliding window attention is not supported for context parallelism")
# Context parallelism without packing: only supports causal masking
Expand Down Expand Up @@ -1401,7 +1401,7 @@ def cudnn_flash_attention(
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis="context",
context_parallel_strategy=self.config.context_parallel_strategy,
# context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)

Expand Down
23 changes: 22 additions & 1 deletion src/MaxText/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,7 +1445,7 @@ def get_einsum(
def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument
# simply skip kwargs, since aqt einsum doesn't support any kwargs
# like precision
is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization)
is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) )
kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error

Expand Down Expand Up @@ -1480,6 +1480,27 @@ def dense_matmul(
wo_bias,
) -> tuple[jax.Array, Optional[jax.Array]]:
"""Dense matrix multiplication."""
if self.config.expert_balance:
######################################################################################################
############################## start hard code for uniform expert ####################################
# Create deterministic rotational pattern for gate logits
batch_size, seq_len, num_experts = gate_logits.shape

# Create base weights for experts (increasing values)
base_weights = jnp.linspace(0.1, 0.1 * num_experts, num_experts, dtype=gate_logits.dtype)

# Create position-based indices matrix [seq_len, num_experts]
# Each row represents which index in base_weights to use after rotation
indices = (jnp.arange(num_experts)[None, :] + jnp.arange(seq_len)[:, None]) % num_experts

# Use advanced indexing to create the rotated weights matrix in one operation
# This takes the appropriate weight for each position based on the rotation pattern
rotated_weights = base_weights[indices]

# Broadcast to batch dimension
gate_logits = jnp.broadcast_to(rotated_weights[None, :, :], (batch_size, seq_len, num_experts))
############################################# end ####################################################
######################################################################################################
# gate_logits: batch, length, expert
gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_norm_length", None))
if self.config.model_name.startswith("deepseek3"):
Expand Down
3 changes: 3 additions & 0 deletions src/MaxText/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns dot_general configured with aqt params."""
return nn.NANOOFp8DotGeneralOp

def einsum(self, dtype: DType = jnp.float32):
return Fp8Einsum(dtype=dtype,e4m3_dtype=jnp.float8_e4m3fnuz,e5m2_dtype=jnp.float8_e5m2fnuz)


def _get_int8_quant_config(config):
drhs_bits = None
Expand Down
31 changes: 31 additions & 0 deletions src/MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Sequence
import functools
from functools import partial
import json
import os
import socket
import subprocess
Expand Down Expand Up @@ -708,6 +709,36 @@ def print_system_information():
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
max_logging.log(f"System Information: Jax Backend: {jax.extend.backend.get_backend().platform_version}")

devices = jax.devices()
max_logging.log(f"System Information: Number of devices: {len(devices)}, jax path {jax.__file__}")
for i, device in enumerate(devices):
if device.local_hardware_id is not None:
max_logging.log(
f"System Information: Device {i}: {device.id} "
f"(Local id: {device.local_hardware_id}, Process index: {device.process_index})"
)


def save_device_information(config):
"""Convert device information to JSON format."""
devices = jax.devices()
device_info = {'hostname': socket.gethostname(), 'devices': []}

for device in devices:
if device.local_hardware_id is not None:
info = {
"id": device.id,
"local_hardware_id": device.local_hardware_id,
"process_index": device.process_index,
"device_kind": device.device_kind,
"platform_version": jax.extend.backend.get_backend().platform_version,
}
device_info['devices'].append(info)
# Save to JSON file
device_info_path = os.path.join(config.base_output_directory, "device_info.json")
with open(device_info_path, "w") as f:
json.dump(device_info, f, indent=4)


def permute_to_match_maxtext_rope(arr):
"""Permutes the Huggingface Rope to match the MaxText logic."""
Expand Down
2 changes: 2 additions & 0 deletions src/MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def train_loop(config, recorder, state=None):
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
state, metrics = p_train_step(state, example_batch, nextrng)
jax.block_until_ready(state)

step_time_delta = datetime.datetime.now() - last_step_completion
last_step_completion = datetime.datetime.now()
Expand Down Expand Up @@ -514,6 +515,7 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
config = pyconfig.initialize(argv)
max_utils.print_system_information()
validate_train_config(config)
max_utils.save_device_information(config)
jax.config.update("jax_use_shardy_partitioner", config.shardy)
# update explicit sharding-supported config
if config.shard_mode == ShardMode.EXPLICIT:
Expand Down
1 change: 1 addition & 0 deletions src/MaxText/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def create_training_tools(config, model, mesh):
logger,
use_ocdbt,
use_zarr3,
config.max_num_checkpoints_to_keep,
)

return init_rng, checkpoint_manager, learning_rate_schedule, tx
Expand Down