perf(ep): optimize dispatch intranode kernel perf in cases with small tokens #333

Open
kawhil-amd wants to merge 9 commits into main from dev/dispatch_opt

kawhil-amd (Contributor) commented May 21, 2026

This PR introduces a warp-group optimization to the IntraNode dispatch kernel (src/ops/dispatch_combine/intranode.hpp). Multiple warps now cooperatively process each token-expert pair, which improves memory bandwidth utilization when the number of tokens is small.
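The warp-group idea can be illustrated with a small host-side model. This is only a sketch under stated assumptions, not the kernel's actual partitioning: the function name `warp_group_plan`, the even split of the hidden vector, and the grid-stride assignment of pairs to groups are all hypothetical and are not taken from intranode.hpp.

```python
def warp_group_plan(num_pairs, total_warps, hidden_dim):
    """Hypothetical model: map each warp to a list of (pair, start, end) slices.

    When there are more warps than token-expert pairs, several warps form a
    group and each warp copies one contiguous slice of the hidden vector.
    """
    # Assumption: warps are split evenly across pairs, at least one per pair.
    warps_per_pair = max(1, total_warps // max(1, num_pairs))
    num_groups = total_warps // warps_per_pair
    # Ceil division so the slices together cover the whole hidden vector.
    chunk = -(-hidden_dim // warps_per_pair)

    plan = {}
    for warp_id in range(total_warps):
        group_id, slot = divmod(warp_id, warps_per_pair)
        if group_id >= num_groups:
            plan[warp_id] = []  # leftover warp with no group stays idle
            continue
        start = min(slot * chunk, hidden_dim)
        end = min(start + chunk, hidden_dim)
        # Each group strides over the pairs so every pair is covered.
        plan[warp_id] = [
            (pair, start, end) for pair in range(group_id, num_pairs, num_groups)
        ]
    return plan


# 1 token x 8 experts per token = 8 pairs; 4 blocks x 4 warps = 16 warps.
plan = warp_group_plan(num_pairs=8, total_warps=16, hidden_dim=7168)
for warp_id in (0, 1, 15):
    print(warp_id, plan[warp_id])
```

In this model, 8 pairs on 16 warps gives two warps per pair, each copying half of the 7168-element hidden vector, instead of leaving half the warps idle.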

Usage

To use the IntraNodeLL kernel, specify kernel_type=IntraNodeLL when creating the config:

import mori
import torch

config = mori.ops.EpDispatchCombineConfig(
    data_type=torch.bfloat16,
    rank=rank,
    world_size=8,
    hidden_dim=7168,
    max_num_inp_token_per_rank=64,
    num_experts_per_rank=32,
    num_experts_per_token=8,
    block_num=64,
    warp_num_per_block=4,
    use_external_inp_buf=True,
    kernel_type=mori.ops.EpDispatchCombineKernelType.IntraNodeLL,
)

op = mori.ops.EpDispatchCombineOp(config)

# Dispatch and combine work the same as IntraNode
dispatch_output, ... = op.dispatch(input_tensor, weights, indices, ...)
combine_output, ... = op.combine(combine_input, weights, ...)

Dispatch Latency (MI355X)

block_num  topk  max_token  Baseline (us)  Current (us)  Speedup
4          8     1          18.6           15.8          +15.08%
32         8     32         23.8           19.16         +19.4%
64         8     64         28.12          24.04         +14.6%
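The description does not define the Speedup column. One reading that roughly matches the reported figures is the relative latency reduction, (baseline - current) / baseline. The sketch below recomputes it from the table values. The formula is an assumption, and the helper name `latency_reduction_pct` is hypothetical.

```python
# (block_num, topk, max_token, baseline_us, current_us, reported_pct) from the table
rows = [
    (4, 8, 1, 18.6, 15.8, 15.08),
    (32, 8, 32, 23.8, 19.16, 19.4),
    (64, 8, 64, 28.12, 24.04, 14.6),
]


def latency_reduction_pct(baseline_us, current_us):
    """Relative latency reduction in percent (assumed meaning of 'Speedup')."""
    return (baseline_us - current_us) / baseline_us * 100.0


for block_num, topk, max_token, base, cur, reported in rows:
    computed = latency_reduction_pct(base, cur)
    print(f"max_token={max_token}: computed {computed:.2f}%, reported {reported}%")
```

Under this assumption the recomputed values agree with the reported ones to within about 0.1 percentage points.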

Test Plan

  • test_dispatch_combine_ll passes for quant_type=none
  • test_dispatch_combine_ll passes for quant_type=fp8_direct_cast
  • test_dispatch_combine_ll passes for quant_type=fp8_blockwise

kawhil-amd and others added 3 commits May 20, 2026 02:56
…rnel

- Explicitly use Unroll=2 for token data WarpCopy in dispatch kernel
  to reduce loop iterations from 14 to 7 for hiddenDim=7168
- Add __launch_bounds__(512, 1) to help compiler optimize register allocation

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
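The "14 to 7" iteration count in the commit message above can be reproduced with simple arithmetic. This is a sketch under assumptions the commit does not state: a 64-lane wavefront, 16-byte vector accesses per lane, and 2-byte bfloat16 elements. The helper name `warp_copy_iterations` is hypothetical and is not the WarpCopy implementation.

```python
import math


def warp_copy_iterations(hidden_dim, elem_bytes, warp_size=64, vec_bytes=16, unroll=1):
    """Loop iterations one warp needs to copy a hidden vector.

    Assumes each lane moves `vec_bytes` per access and the loop body is
    unrolled `unroll` times, so one iteration moves
    warp_size * vec_bytes * unroll bytes.
    """
    bytes_per_iter = warp_size * vec_bytes * unroll
    return math.ceil(hidden_dim * elem_bytes / bytes_per_iter)


# hiddenDim=7168 in bfloat16 (2 bytes per element) is 14336 bytes per token.
print(warp_copy_iterations(7168, 2, unroll=1))
print(warp_copy_iterations(7168, 2, unroll=2))
```

With these assumptions a single warp moves 1024 bytes per iteration, so Unroll=2 halves the trip count for a 14336-byte token.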
kawhil-amd requested review from TianDi101 and jhchouuu May 21, 2026 08:47
isytwu (Collaborator) commented May 21, 2026

Maybe you could record the data before and after optimization in this PR, as well as block/warp, for 128 and 4096 tokens, respectively?

kawhil-amd (Contributor, Author) commented

> Maybe you could record the data before and after optimization in this PR, as well as block/warp, for 128 and 4096 tokens, respectively?

Sure, will do it later.

kawhil-amd self-assigned this May 21, 2026
Comment thread src/ops/dispatch_combine/intranode.hpp Outdated
Comment thread src/ops/dispatch_combine/intranode.hpp Outdated
Comment thread src/ops/dispatch_combine/intranode.hpp Outdated
kawhil-amd requested a review from isytwu May 22, 2026 03:39
kawhil-amd changed the title from "perf(ep): dispatch intranode opt" to "perf(ep): dispatch intranode opt within small token" May 22, 2026
kawhil-amd changed the title from "perf(ep): dispatch intranode opt within small token" to "perf(ep): optimize dispatch intranode kernel performance in cases with small tokens" May 28, 2026
kawhil-amd changed the title from "perf(ep): optimize dispatch intranode kernel performance in cases with small tokens" to "perf(ep): optimize dispatch intranode kernel perf in cases with small tokens" May 28, 2026
kawhil-amd force-pushed the dev/dispatch_opt branch 2 times, most recently from 885948e to 1df62b1 June 2, 2026 06:03