Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ jobs:
- name: MORI-EP (intranode)
run: |
$CT exec -e PYTHONPATH=$GITHUB_WORKSPACE $CONTAINER bash -c "
cd $GITHUB_WORKSPACE && timeout 300 pytest tests/python/ops/test_dispatch_combine_intranode.py -v
cd $GITHUB_WORKSPACE && timeout 360 pytest tests/python/ops/test_dispatch_combine_intranode.py -v
"

- name: MORI-EP (internode_v1)
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ jobs:
- name: MORI-EP (intranode)
run: |
$CT exec -e PYTHONPATH=$GITHUB_WORKSPACE $CONTAINER bash -c "
cd $GITHUB_WORKSPACE && timeout 300 pytest tests/python/ops/test_dispatch_combine_intranode.py -v
cd $GITHUB_WORKSPACE && timeout 360 pytest tests/python/ops/test_dispatch_combine_intranode.py -v
"

- name: MORI-EP (internode_v1)
Expand Down
15 changes: 11 additions & 4 deletions include/mori/ops/dispatch_combine/dispatch_combine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,14 @@
namespace mori {
namespace moe {

/// Selects which dispatch/combine kernel family a handle is configured for.
///
/// Every enumerator carries an explicit numeric value. IntraNodeLL takes the
/// next free value (5), so the values of the five earlier enumerators are
/// unchanged.
enum KernelType {
  IntraNode = 0,
  InterNode = 1,
  InterNodeV1 = 2,
  InterNodeV1LL = 3,
  AsyncLL = 4,
  // Handled together with IntraNode wherever the ShmemBufsIntraNode token
  // buffers are selected.
  IntraNodeLL = 5
};
enum class QuantType { None = 0, Fp8DirectCast = 1, Fp8BlockwiseQuant = 2 };

inline const char* HipDataTypeToString(hipDataType dtype) {
Expand Down Expand Up @@ -238,23 +245,23 @@ class EpDispatchCombineHandle {
int Fp8BlockwiseCombineScaleTypeSize() const { return fp8BlockwiseCombineScaleTypeSize; }

/// Returns the symmetric-memory object that holds the dispatch output tokens
/// for the configured kernel type.
///
/// IntraNode and IntraNodeLL both use ShmemBufsIntraNode; InterNodeV1 and
/// InterNodeV1LL both use ShmemBufsInterNodeV1; every other kernel type falls
/// through to ShmemBufsInterNode. Each std::get requires shmemTokBufs to
/// currently hold the alternative named in that branch.
mori::application::SymmMemObjPtr GetShmemDispatchOutTokMemObj() const {
  if (config.kernelType == KernelType::IntraNode ||
      config.kernelType == KernelType::IntraNodeLL)
    return std::get<ShmemBufsIntraNode>(shmemTokBufs).dispatchOut;
  if (config.kernelType == KernelType::InterNodeV1 ||
      config.kernelType == KernelType::InterNodeV1LL)
    return std::get<ShmemBufsInterNodeV1>(shmemTokBufs).dispatchOut;
  return std::get<ShmemBufsInterNode>(shmemTokBufs).dispatchOut;
}
/// Returns the symmetric-memory object that holds the combine output tokens
/// for the configured kernel type.
///
/// IntraNode and IntraNodeLL both use ShmemBufsIntraNode; InterNodeV1 and
/// InterNodeV1LL both use ShmemBufsInterNodeV1; every other kernel type falls
/// through to ShmemBufsInterNode. Each std::get requires shmemTokBufs to
/// currently hold the alternative named in that branch.
mori::application::SymmMemObjPtr GetShmemCombineOutTokMemObj() const {
  if (config.kernelType == KernelType::IntraNode ||
      config.kernelType == KernelType::IntraNodeLL)
    return std::get<ShmemBufsIntraNode>(shmemTokBufs).combineOut;
  if (config.kernelType == KernelType::InterNodeV1 ||
      config.kernelType == KernelType::InterNodeV1LL)
    return std::get<ShmemBufsInterNodeV1>(shmemTokBufs).combineOut;
  return std::get<ShmemBufsInterNode>(shmemTokBufs).combineOut;
}
mori::application::SymmMemObjPtr GetShmemCombineInpTokMemObj() const {
if (config.kernelType == KernelType::IntraNode)
if (config.kernelType == KernelType::IntraNode || config.kernelType == KernelType::IntraNodeLL)
return std::get<ShmemBufsIntraNode>(shmemTokBufs).combineInp;
if (config.kernelType == KernelType::InterNodeV1 ||
config.kernelType == KernelType::InterNodeV1LL)
Expand Down
23 changes: 20 additions & 3 deletions python/mori/ops/dispatch_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _cpp_dispatch_combine_factory(entity_name, allow_missing=False):
# ---------------------------------------------------------------------------
_KERNEL_TYPE_TO_HIP = {
EpDispatchCombineKernelType.IntraNode: "ep_intranode",
EpDispatchCombineKernelType.IntraNodeLL: "ep_intranode",
EpDispatchCombineKernelType.InterNode: "ep_internode",
EpDispatchCombineKernelType.InterNodeV1: "ep_internode_v1",
EpDispatchCombineKernelType.InterNodeV1LL: "ep_internode_v1ll",
Expand Down Expand Up @@ -573,6 +574,15 @@ def dispatch(
stream,
args_ptr,
)
elif kt == EpDispatchCombineKernelType.IntraNodeLL.value:
self._launch(
f"EpDispatchIntraNodeLLKernel_{sfx}",
grid,
block,
shared_mem,
stream,
args_ptr,
)
elif kt == EpDispatchCombineKernelType.AsyncLL.value:
mp = self._handle_info["multi_processor_count"]
mp_aligned = mp // self.config.world_size * self.config.world_size
Expand Down Expand Up @@ -753,9 +763,12 @@ def combine(
shared_mem = self._combine_shared_mem(actual_wpb)

if quant_type == EpDispatchCombineQuantType.Fp8BlockwiseQuant:
if kt != EpDispatchCombineKernelType.IntraNode.value:
if kt not in (
EpDispatchCombineKernelType.IntraNode.value,
EpDispatchCombineKernelType.IntraNodeLL.value,
):
raise ValueError(
"Fp8BlockwiseQuant currently only supports IntraNode combine"
"Fp8BlockwiseQuant currently only supports IntraNode/IntraNodeLL combine"
)
if sfx != "bf16":
raise ValueError(f"Fp8BlockwiseQuant only supports bf16, got {sfx}")
Expand Down Expand Up @@ -810,7 +823,10 @@ def combine(
stream,
args_ptr,
)
elif kt == EpDispatchCombineKernelType.IntraNode.value:
elif kt in (
EpDispatchCombineKernelType.IntraNode.value,
EpDispatchCombineKernelType.IntraNodeLL.value,
):
if quant_type == EpDispatchCombineQuantType.Fp8BlockwiseQuant:
# Mirror of the AccumNum=8 + VecBytes=8 specialization gating in
# LaunchCombine() / launch.cpp. Keep in sync.
Expand Down Expand Up @@ -1366,6 +1382,7 @@ def get_dispatch_src_token_pos(self):

if self.config.kernel_type.value in (
EpDispatchCombineKernelType.IntraNode.value,
EpDispatchCombineKernelType.IntraNodeLL.value,
EpDispatchCombineKernelType.InterNodeV1.value,
EpDispatchCombineKernelType.InterNodeV1LL.value,
EpDispatchCombineKernelType.AsyncLL.value,
Expand Down
2 changes: 1 addition & 1 deletion python/mori/ops/tuning_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
CONFIG_STR_TO_SHORT_NAME: dict[str, str] = {r[1]: r[2] for r in _DTYPE_REGISTRY}

# Immutable set of the recognised kernel-type names, including "IntraNodeLL".
# NOTE(review): presumably used to validate the `kernel_type` field of tuning
# configs -- confirm against the callers in this module.
_KERNEL_TYPE_NAMES = frozenset(
    {
        "IntraNode",
        "InterNode",
        "InterNodeV1",
        "InterNodeV1LL",
        "AsyncLL",
        "IntraNodeLL",
    }
)

_QUANT_TYPE_CONFIG_STRS = {"none", "fp8_direct_cast", "fp8_blockwise"}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"version": "1.0",
"gpu_arch": "gfx950",
"gpu_model": "mi355x",
"kernel_type": "IntraNodeLL",
"ep_size": 8,
"phase": "dispatch",
"rules": [
{
"dtype": "bf16",
"num_tokens": 1,
"hidden_dim": 7168,
"block_num": 4,
"rdma_block_num": 0,
"warp_per_block": 4,
"bandwidth_gbps": 2.56,
"latency_us": 27.7
},
{
"dtype": "bf16",
"num_tokens": 32,
"hidden_dim": 7168,
"block_num": 32,
"rdma_block_num": 0,
"warp_per_block": 16,
"bandwidth_gbps": 79.9,
"latency_us": 30.9
},
{
"dtype": "bf16",
"num_tokens": 64,
"hidden_dim": 7168,
"block_num": 64,
"rdma_block_num": 0,
"warp_per_block": 16,
"bandwidth_gbps": 145.95,
"latency_us": 34.3
}
]
}
6 changes: 4 additions & 2 deletions src/ops/dispatch_combine/convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ template <typename T>
__device__ inline void InvokeConvertDispatchOutput(const EpDispatchCombineArgs<T>& args, int myPe) {
ConvertDispatchOutputArgs convArgs{};
convArgs.config = args.config;
if (args.config.kernelType == KernelType::IntraNode) {
if (args.config.kernelType == KernelType::IntraNode ||
args.config.kernelType == KernelType::IntraNodeLL) {
convArgs.dispatchOutX = args.intraNodeTokBufs.dispatchOut->template GetAs<T*>(myPe);
} else if (args.config.kernelType == KernelType::InterNodeV1 ||
args.config.kernelType == KernelType::InterNodeV1LL) {
Expand Down Expand Up @@ -218,7 +219,8 @@ __device__ inline void InvokeConvertCombineInput(const EpDispatchCombineArgs<T>&
convArgs.combineInput = nullptr;
convArgs.dispTokToEpSlotMap = args.dispTokToEpSlotMap;
convArgs.packedRecvCount = args.standardPackedRecvCount;
if (args.config.kernelType == KernelType::IntraNode) {
if (args.config.kernelType == KernelType::IntraNode ||
args.config.kernelType == KernelType::IntraNodeLL) {
convArgs.shmemCombineInpTokMemObj = args.intraNodeTokBufs.combineInp;
} else if (args.config.kernelType == KernelType::InterNodeV1 ||
args.config.kernelType == KernelType::InterNodeV1LL) {
Expand Down
7 changes: 4 additions & 3 deletions src/ops/dispatch_combine/dispatch_combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
config.WeightBytes() + config.SrcTokenIdBytes() + blockwiseScaleBytes);
}

if (config.kernelType == KernelType::IntraNode) {
if (config.kernelType == KernelType::IntraNode || config.kernelType == KernelType::IntraNodeLL) {
auto& bufs = shmemTokBufs.emplace<ShmemBufsIntraNode>();
bufs.combineInp = ShmemMallocAndReturnMemObjPtr(maxStagingSize, hipDeviceMallocUncached);
bufs.dispatchOut = ShmemMallocAndReturnMemObjPtr(dispatchOutSize, hipDeviceMallocUncached);
Expand Down Expand Up @@ -271,7 +271,7 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
}

void EpDispatchCombineHandle::FinalizeShmemBuf() {
if (config.kernelType == KernelType::IntraNode) {
if (config.kernelType == KernelType::IntraNode || config.kernelType == KernelType::IntraNodeLL) {
auto& bufs = std::get<ShmemBufsIntraNode>(shmemTokBufs);
ShmemFree(bufs.dispatchOut->localPtr);
ShmemFree(bufs.combineInp->localPtr);
Expand Down Expand Up @@ -463,7 +463,8 @@ EpDispatchCombineArgsRaw GetEpDispatchCombineArgsRaw(const EpDispatchCombineHand
args.scalesBuf = handle.scalesBuf;
args.destPeTokenCounter = handle.destPeTokenCounter;
args.localPeTokenCounter = handle.localPeTokenCounter;
if (handle.config.kernelType == KernelType::IntraNode) {
if (handle.config.kernelType == KernelType::IntraNode ||
handle.config.kernelType == KernelType::IntraNodeLL) {
args.intraNodeTokBufs = std::get<ShmemBufsIntraNode>(handle.shmemTokBufs);
} else if (handle.config.kernelType == KernelType::InterNodeV1 ||
handle.config.kernelType == KernelType::InterNodeV1LL) {
Expand Down
Loading
Loading