Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
fe9a66c
[ROCm] port tdm to npi_gfx1250
wangye805 Apr 22, 2026
40f3902
[ROCm] address first round reviewer comments
wangye805 Apr 23, 2026
cb421f8
[ROCm] address reviewer comments
wangye805 Apr 23, 2026
0e2d24d
[ROCm] address more reviewer comments
wangye805 Apr 23, 2026
0a13bbc
[ROCm] Address TDM review comments: remove extra params, add explanat…
wangye805 Apr 23, 2026
bfb7199
[ROCm] Address remaining review comments and enable TDM flow in CI gt…
wangye805 Apr 23, 2026
ba4bbb7
tdm: clamp tensorDim to avoid uint32_t underflow on OOB prefetch tiles
wangye805 Apr 24, 2026
ab77fbf
tdm: add HIPTensorMap descriptor struct; revert TDM from rocm_* kernels
wangye805 Apr 25, 2026
a0a60fe
tdm: fully revert rocm_*.cuh to branch-point state
wangye805 Apr 25, 2026
506d78c
tdm: extract ROCm flow into separate rocm_* launchers; TDM stays in m…
wangye805 Apr 25, 2026
acc7e4f
tdm: address review comments for cast_gated_kernels.cuh
wangye805 Apr 26, 2026
3c85101
tdm: revert swizzled_* lines to NV upstream position
wangye805 Apr 26, 2026
0007c88
tdm: fix cast_mxfp8_gated to match NV upstream structure
wangye805 Apr 26, 2026
bacd226
tdm: use switch(scaling_type) for AMD TDM mxfp8 gated dispatch
wangye805 Apr 26, 2026
4ba5883
tdm: hoist shared next-stage offset vars above #ifdef in cast_mxfp8_g…
wangye805 Apr 26, 2026
7dbf218
tdm: hoist shared shmem computation above #ifdef in cast_fp8_gated
wangye805 Apr 26, 2026
293d970
tdm: collapse duplicate switch(scaling_type) blocks in cast_mxfp8_gated
wangye805 Apr 26, 2026
53b5d22
tdm: remove tma_flow namespace; prefix ROCm-specific constants with R…
wangye805 Apr 26, 2026
a0a9ab6
util: apply ROCM_ prefix to ROCm-specific constants in cast/dequantiz…
wangye805 Apr 26, 2026
fdb4b1a
util: address PR review comments on cast_kernels.cuh and dequantize_k…
wangye805 Apr 26, 2026
a732d35
util: address 4 more PR review comments on cast/dequantize kernels
wangye805 Apr 26, 2026
1a004df
util: hoist shared next-iter offset vars above #ifndef in cast_mxfp8_…
wangye805 Apr 26, 2026
573f8d7
Revert " Remove padding from scales for hipBLASlt calls (#442)"
wangye805 Apr 26, 2026
fec2de5
fix(rocm): correct double-prefixed constants in rocm_cast_gated_kerne…
wangye805 Apr 26, 2026
004d59f
fix(rocm): add TMA_SHMEM_ALIGNMENT alias and sigmoidf for AMD compila…
wangye805 Apr 26, 2026
7c86c98
fix(rocm): fix fp8_quantize AMD flow — remove unavailable fp8_quantiz…
wangye805 Apr 26, 2026
0b40533
fix(rocm): route NVTE_MXFP8_1D_SCALING through fp8_quantize_rocm on AMD
wangye805 Apr 27, 2026
c89b5ff
fix(rocm): use padded scales_stride in rocm_mxfp8_dequantize
wangye805 Apr 27, 2026
9f55a8b
fix(rocm): guard TDM flow dispatch behind __gfx1250__ on AMD
wangye805 Apr 27, 2026
8338725
fix(rocm): wire up cast_mxfp8_2D_kernel launch on gfx1250 TDM path
wangye805 Apr 27, 2026
14329d5
refactor(rocm): consolidate mxfp8_quantize kernel launch for TDM and TMA
wangye805 Apr 27, 2026
d38c6bd
fix(amd): guard cudaFuncSetAttribute and add hip_bfloat16 overloads f…
wangye805 Apr 27, 2026
14a1dab
chore: remove debug print statements from MXFP8 cast/dequantize kernels
wangye805 Apr 29, 2026
198495a
chore: restore launcher debug prints, remove only in-kernel printf st…
wangye805 Apr 29, 2026
362ae53
test: add 16384x16384 matrix size to CastMXFP8_GatedAct benchmark run
wangye805 Apr 29, 2026
0456492
feat: migrate benchmarks/cpp/cast from dev branch
wangye805 Apr 29, 2026
573f6ea
build: add rocm_utils.cmake needed by benchmarks/cpp CMakeLists
wangye805 Apr 29, 2026
9f340a1
fix: suppress clang warnings in Google Benchmark for gfx1250 toolchain
wangye805 Apr 29, 2026
09ed78c
test: remove 16384x16384 from gated swiglu test (causes CPU ref hang)
wangye805 Apr 29, 2026
b02fe76
fix(rocm): remove TDM debug prints and fix NVTE_ROCM_BENCHMARK guards…
wangye805 Apr 30, 2026
186d793
fix(rocm): restore indentation lost during debug print removal
wangye805 Apr 30, 2026
9862745
fix(rocm): remove segfault debug handler and restore indentation in b…
wangye805 Apr 30, 2026
1afc5b2
fix(rocm): restore indentation in fp8_quantize_rocm TDM branch
wangye805 Apr 30, 2026
5da5014
Remove leftover debug printf from tdm.cuh copy_2d_to_shared
wangye805 Apr 30, 2026
65e92ab
Fix TDM double-buffer: use wait_tensorcnt_1 after store to preserve o…
wangye805 Apr 30, 2026
8ea1cbd
Make TDM store wait robust via wait_tensorcnt<N>() template
wangye805 Apr 30, 2026
882bea2
cast_gated_kernels: use wait_tensorcnt<TDM_PREFETCH_LOADS>() template
wangye805 Apr 30, 2026
fcf1932
tdm: rename is_tdm_wave() to is_tdm_lane(), restrict to thread 0
wangye805 Apr 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 0 additions & 115 deletions transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip
Comment thread
wangye805 marked this conversation as resolved.

This file was deleted.

56 changes: 54 additions & 2 deletions transformer_engine/common/util/rocm_cast_gated_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "math.h"
#include "ptx.cuh"
#include "rocm_vectorized_2d.cuh"
#include "tdm.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "vectorized_pointwise.h"
Expand Down Expand Up @@ -134,6 +135,28 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t row_base = chunk_it_offset_y;

// Initiate bulk tensor copy
#if defined(__gfx1250__)
Comment thread
wangye805 marked this conversation as resolved.
Outdated
{
constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8);
if constexpr (IS_DGATED) {
// grad uses stride=cols, act/gate use stride=2*cols -- issue separately
tdm::copy_2d_to_shared(
&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, data_sz);
tdm::copy_2d_to_shared_x2(
&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz);
} else {
tdm::copy_2d_to_shared_x2(
&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz);
}
tdm::wait_tensorcnt_0();
__syncthreads();
}
#else
if constexpr (IS_DGATED) {
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
Expand All @@ -142,12 +165,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Act
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);

// Gate
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);

__syncthreads();
#endif

const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it;
const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y;
Expand Down Expand Up @@ -353,6 +377,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)

__syncthreads();

#if defined(__gfx1250__)
{
constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8);
if constexpr (USE_ROWWISE_SCALING) {
tdm::store_2d_to_global(&out_act_rowwise_sh[0], output_act_rowwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
if constexpr (IS_DGATED) {
tdm::store_2d_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
}
}
if constexpr (USE_COLWISE_SCALING) {
tdm::store_2d_to_global(&out_act_colwise_sh[0], output_act_colwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
if constexpr (IS_DGATED) {
tdm::store_2d_to_global(&out_gate_colwise_sh[0], output_gate_colwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
}
}
tdm::wait_tensorcnt_0();
__syncthreads();
}
#else
if constexpr (USE_ROWWISE_SCALING) {
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
Expand All @@ -361,7 +412,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
}
}

if constexpr (USE_COLWISE_SCALING) {
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
Expand All @@ -371,6 +422,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
}
__syncthreads();
#endif
}
}
} // namespace gated_kernels
Expand Down
41 changes: 38 additions & 3 deletions transformer_engine/common/util/rocm_cast_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "math.h"
#include "ptx.cuh"
#include "rocm_vectorized_2d.cuh"
#include "tdm.cuh"
#include "transformer_engine/cast.h"
#include "../transpose/cast_transpose.h"
#include "vectorized_pointwise.h"
Expand Down Expand Up @@ -161,15 +162,31 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t row_base = chunk_it_offset_y;
#if defined(__gfx1250__)
constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8);
if constexpr (IS_DACT) {
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&act_in_sh[0][0], act_input_ptr,
chunk_it_offset_x, chunk_it_offset_y, cols,
tdm::copy_2d_to_shared_x2(
&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y,
&act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz);
} else {
tdm::copy_2d_to_shared(
&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz);
}
tdm::wait_tensorcnt_0();
__syncthreads();
#else
if constexpr (IS_DACT) {
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&act_in_sh[0][0], act_input_ptr,
chunk_it_offset_x, chunk_it_offset_y, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols);
}
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, rows, cols);
__syncthreads();
#endif

if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in;
Expand Down Expand Up @@ -312,6 +329,23 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)

__syncthreads();

#if defined(__gfx1250__)
constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8);
if constexpr (USE_ROWWISE_SCALING) {
tdm::store_2d_to_global(&out_rowwise_sh[0][0], output_rowwise,
chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y,
cols, rows, cols, out_data_sz);
}
if constexpr (USE_COLWISE_SCALING) {
tdm::store_2d_to_global(&out_colwise_sh[0][0], output_colwise,
chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y,
cols, rows, cols, out_data_sz);
}
tdm::wait_tensorcnt_0();
__syncthreads();
#else
if constexpr (USE_ROWWISE_SCALING) {
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x,
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
Expand All @@ -324,6 +358,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
}

__syncthreads();
#endif
}
}

Expand Down
25 changes: 24 additions & 1 deletion transformer_engine/common/util/rocm_dequantize_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "math.h"
#include "ptx.cuh"
#include "rocm_vectorized_2d.cuh"
#include "tdm.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "../transpose/cast_transpose.h"
Expand Down Expand Up @@ -85,10 +86,21 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;

copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
#if defined(__gfx1250__)
{
constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8);
tdm::copy_2d_to_shared(&in_sh[0][0], input_ptr,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, data_sz);
tdm::wait_tensorcnt_0();
__syncthreads();
}
#else
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
chunk_it_offset_y, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, rows, cols);
__syncthreads();
#endif

const int scale_offset_Y =
USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y)
Expand Down Expand Up @@ -126,11 +138,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)

__syncthreads();

#if defined(__gfx1250__)
{
constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8);
tdm::store_2d_to_global(&out_sh[0][0], output_ptr,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, out_data_sz);
tdm::wait_tensorcnt_0();
__syncthreads();
}
#else
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_sh[0][0], output_ptr, chunk_it_offset_x,
chunk_it_offset_y, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, rows, cols);

__syncthreads();
#endif
}
}
} // namespace dequantization
Expand Down
Loading