Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
fe9a66c
[ROCm] port tdm to npi_gfx1250
wangye805 Apr 22, 2026
40f3902
[ROCm] address first round reviewer comments
wangye805 Apr 23, 2026
cb421f8
[ROCm] address reviewer comments
wangye805 Apr 23, 2026
0e2d24d
[ROCm] address more reviewer comments
wangye805 Apr 23, 2026
0a13bbc
[ROCm] Address TDM review comments: remove extra params, add explanat…
wangye805 Apr 23, 2026
bfb7199
[ROCm] Address remaining review comments and enable TDM flow in CI gt…
wangye805 Apr 23, 2026
ba4bbb7
tdm: clamp tensorDim to avoid uint32_t underflow on OOB prefetch tiles
wangye805 Apr 24, 2026
ab77fbf
tdm: add HIPTensorMap descriptor struct; revert TDM from rocm_* kernels
wangye805 Apr 25, 2026
a0a60fe
tdm: fully revert rocm_*.cuh to branch-point state
wangye805 Apr 25, 2026
506d78c
tdm: extract ROCm flow into separate rocm_* launchers; TDM stays in m…
wangye805 Apr 25, 2026
acc7e4f
tdm: address review comments for cast_gated_kernels.cuh
wangye805 Apr 26, 2026
3c85101
tdm: revert swizzled_* lines to NV upstream position
wangye805 Apr 26, 2026
0007c88
tdm: fix cast_mxfp8_gated to match NV upstream structure
wangye805 Apr 26, 2026
bacd226
tdm: use switch(scaling_type) for AMD TDM mxfp8 gated dispatch
wangye805 Apr 26, 2026
4ba5883
tdm: hoist shared next-stage offset vars above #ifdef in cast_mxfp8_g…
wangye805 Apr 26, 2026
7dbf218
tdm: hoist shared shmem computation above #ifdef in cast_fp8_gated
wangye805 Apr 26, 2026
293d970
tdm: collapse duplicate switch(scaling_type) blocks in cast_mxfp8_gated
wangye805 Apr 26, 2026
53b5d22
tdm: remove tma_flow namespace; prefix ROCm-specific constants with R…
wangye805 Apr 26, 2026
a0a9ab6
util: apply ROCM_ prefix to ROCm-specific constants in cast/dequantiz…
wangye805 Apr 26, 2026
fdb4b1a
util: address PR review comments on cast_kernels.cuh and dequantize_k…
wangye805 Apr 26, 2026
a732d35
util: address 4 more PR review comments on cast/dequantize kernels
wangye805 Apr 26, 2026
1a004df
util: hoist shared next-iter offset vars above #ifndef in cast_mxfp8_…
wangye805 Apr 26, 2026
573f8d7
Revert " Remove padding from scales for hipBLASlt calls (#442)"
wangye805 Apr 26, 2026
fec2de5
fix(rocm): correct double-prefixed constants in rocm_cast_gated_kerne…
wangye805 Apr 26, 2026
004d59f
fix(rocm): add TMA_SHMEM_ALIGNMENT alias and sigmoidf for AMD compila…
wangye805 Apr 26, 2026
7c86c98
fix(rocm): fix fp8_quantize AMD flow — remove unavailable fp8_quantiz…
wangye805 Apr 26, 2026
0b40533
fix(rocm): route NVTE_MXFP8_1D_SCALING through fp8_quantize_rocm on AMD
wangye805 Apr 27, 2026
c89b5ff
fix(rocm): use padded scales_stride in rocm_mxfp8_dequantize
wangye805 Apr 27, 2026
9f55a8b
fix(rocm): guard TDM flow dispatch behind __gfx1250__ on AMD
wangye805 Apr 27, 2026
8338725
fix(rocm): wire up cast_mxfp8_2D_kernel launch on gfx1250 TDM path
wangye805 Apr 27, 2026
14329d5
refactor(rocm): consolidate mxfp8_quantize kernel launch for TDM and TMA
wangye805 Apr 27, 2026
d38c6bd
fix(amd): guard cudaFuncSetAttribute and add hip_bfloat16 overloads f…
wangye805 Apr 27, 2026
14a1dab
chore: remove debug print statements from MXFP8 cast/dequantize kernels
wangye805 Apr 29, 2026
198495a
chore: restore launcher debug prints, remove only in-kernel printf st…
wangye805 Apr 29, 2026
362ae53
test: add 16384x16384 matrix size to CastMXFP8_GatedAct benchmark run
wangye805 Apr 29, 2026
0456492
feat: migrate benchmarks/cpp/cast from dev branch
wangye805 Apr 29, 2026
573f6ea
build: add rocm_utils.cmake needed by benchmarks/cpp CMakeLists
wangye805 Apr 29, 2026
9f340a1
fix: suppress clang warnings in Google Benchmark for gfx1250 toolchain
wangye805 Apr 29, 2026
09ed78c
test: remove 16384x16384 from gated swiglu test (causes CPU ref hang)
wangye805 Apr 29, 2026
b02fe76
fix(rocm): remove TDM debug prints and fix NVTE_ROCM_BENCHMARK guards…
wangye805 Apr 30, 2026
186d793
fix(rocm): restore indentation lost during debug print removal
wangye805 Apr 30, 2026
9862745
fix(rocm): remove segfault debug handler and restore indentation in b…
wangye805 Apr 30, 2026
1afc5b2
fix(rocm): restore indentation in fp8_quantize_rocm TDM branch
wangye805 Apr 30, 2026
5da5014
Remove leftover debug printf from tdm.cuh copy_2d_to_shared
wangye805 Apr 30, 2026
65e92ab
Fix TDM double-buffer: use wait_tensorcnt_1 after store to preserve o…
wangye805 Apr 30, 2026
8ea1cbd
Make TDM store wait robust via wait_tensorcnt<N>() template
wangye805 Apr 30, 2026
882bea2
cast_gated_kernels: use wait_tensorcnt<TDM_PREFETCH_LOADS>() template
wangye805 Apr 30, 2026
fcf1932
tdm: rename is_tdm_wave() to is_tdm_lane(), restrict to thread 0
wangye805 Apr 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
554 changes: 465 additions & 89 deletions transformer_engine/common/util/cast_gated_kernels.cuh

Large diffs are not rendered by default.

374 changes: 342 additions & 32 deletions transformer_engine/common/util/cast_kernels.cuh

Large diffs are not rendered by default.

94 changes: 85 additions & 9 deletions transformer_engine/common/util/dequantize_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,26 @@
#include "transformer_engine/transpose.h"
#ifdef __HIP_PLATFORM_AMD__
#include "rocm_dequantize_kernels.cuh"
#include "tdm.cuh"
#endif

namespace transformer_engine {

namespace dequantization {

#ifndef __HIP_PLATFORM_AMD__
template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
dequantize_mxfp8_kernel(
#ifdef __HIP_PLATFORM_AMD__
const IType *__restrict__ input_ptr,
OType *__restrict__ output_ptr,
#else
const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
#endif
const e8m0_t *const scales_ptr, const size_t rows, const size_t cols,
const size_t scales_stride) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;

constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128
Expand Down Expand Up @@ -75,15 +81,23 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD;
// const int thread_offset_X_colwise = tid_colwise_X;

// The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
#ifdef __HIP_PLATFORM_AMD__
alignas(128) __shared__ IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
alignas(128) __shared__ OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
#else
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
#endif

constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM;
#ifndef __HIP_PLATFORM_AMD__
constexpr int transaction_size = shmem_buff_size;
#endif

const bool is_master_thread = (threadIdx.x == 0);

#ifndef __HIP_PLATFORM_AMD__
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
Expand Down Expand Up @@ -118,12 +132,25 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Other threads just arrive
ptx::mbarrier_arrive(&mbar[iteration_zero]);
}
#else // __HIP_PLATFORM_AMD__ — TDM prefetch
constexpr uint32_t deq_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8);
constexpr uint32_t deq_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8);

// Prefetch first iteration
tdm::copy_2d_to_shared(&in_sh[0][0][0], input_ptr,
chunk_offset_X, chunk_offset_Y,
SHMEM_DIM_X, SHMEM_DIM_Y,
cols, rows, cols, deq_in_data_sz);
tdm::wait_tensorcnt_0();
__syncthreads();
#endif // __HIP_PLATFORM_AMD__

#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
const int buff = iter % BUFFERS_NUM;
const int next_iter = iter + 1;
if (next_iter < ITERATIONS) {
#ifndef __HIP_PLATFORM_AMD__
if (is_master_thread) {
const int next_buff = next_iter % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y;
Expand All @@ -140,12 +167,28 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Other threads just arrive
ptx::mbarrier_arrive(&mbar[next_iter]);
}
#else // __HIP_PLATFORM_AMD__ — TDM prefetch next iteration
{
const int next_buff = next_iter % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
tdm::copy_2d_to_shared(&in_sh[next_buff][0][0], input_ptr,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y,
cols, rows, cols, deq_in_data_sz);
}
#endif // __HIP_PLATFORM_AMD__
}

#ifndef __HIP_PLATFORM_AMD__
ptx::fence_proxy_async_shared_cta();

// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
#else
tdm::wait_tensorcnt_0();
__syncthreads();
#endif

const int scale_offset_Y =
USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y)
Expand Down Expand Up @@ -181,6 +224,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
}

#ifndef __HIP_PLATFORM_AMD__
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
Expand All @@ -200,7 +244,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<1>();
}
#else // __HIP_PLATFORM_AMD__ — TDM store
__syncthreads();
{
const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
tdm::store_2d_to_global(&out_sh[buff][0][0], output_ptr,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y,
cols, rows, cols, deq_out_data_sz);
tdm::wait_tensorcnt_0();
__syncthreads();
}
#endif // __HIP_PLATFORM_AMD__
}

#ifndef __HIP_PLATFORM_AMD__
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();

Expand All @@ -215,9 +274,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mbarrier_invalid(&mbar[iter]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#else
tdm::wait_tensorcnt_0();
#endif
#endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
}
#endif // #ifndef __HIP_PLATFORM_AMD__

static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
Expand Down Expand Up @@ -312,9 +373,24 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
#ifdef __HIP_PLATFORM_AMD__
TRANSFORMER_ENGINE_SWITCH_CONDITION(
!(cols % (32 * sizeof(OType))), IS_ALIGNED,
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X, IS_ALIGNED>
<<<grid, block, 0, stream>>>(reinterpret_cast<const IType *>(input_data.dptr), reinterpret_cast<OType *>(output->data.dptr), scales_ptr,
rows, cols, scales_stride);); // NOLINT(*)
{
const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW");
if (env && std::string(env) == "1") {
// NV upstream kernel with TDM
dequantization::dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
<<<grid, block, 0, stream>>>(
reinterpret_cast<const IType *>(input_data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
scales_ptr, rows, cols, scales_stride);
} else {
// Default ROCm flow
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X, IS_ALIGNED>
<<<grid, block, 0, stream>>>(
reinterpret_cast<const IType *>(input_data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
scales_ptr, rows, cols, scales_stride);
}
}); // NOLINT(*)
#else // #ifdef __HIP_PLATFORM_AMD__
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
Expand Down
56 changes: 54 additions & 2 deletions transformer_engine/common/util/rocm_cast_gated_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "math.h"
#include "ptx.cuh"
#include "rocm_vectorized_2d.cuh"
#include "tdm.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "vectorized_pointwise.h"
Expand Down Expand Up @@ -134,6 +135,28 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t row_base = chunk_it_offset_y;

// Initiate bulk tensor copy
#if defined(__gfx1250__)
Comment thread
wangye805 marked this conversation as resolved.
Outdated
{
constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8);
if constexpr (IS_DGATED) {
// grad uses stride=cols, act/gate use stride=2*cols -- issue separately
tdm::copy_2d_to_shared(
&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, data_sz);
tdm::copy_2d_to_shared_x2(
&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz);
} else {
tdm::copy_2d_to_shared_x2(
&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz);
}
tdm::wait_tensorcnt_0();
__syncthreads();
}
#else
if constexpr (IS_DGATED) {
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
Expand All @@ -142,12 +165,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Act
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);

// Gate
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);

__syncthreads();
#endif

const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it;
const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y;
Expand Down Expand Up @@ -353,6 +377,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)

__syncthreads();

#if defined(__gfx1250__)
{
constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8);
if constexpr (USE_ROWWISE_SCALING) {
tdm::store_2d_to_global(&out_act_rowwise_sh[0], output_act_rowwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
if constexpr (IS_DGATED) {
tdm::store_2d_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
}
}
if constexpr (USE_COLWISE_SCALING) {
tdm::store_2d_to_global(&out_act_colwise_sh[0], output_act_colwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
if constexpr (IS_DGATED) {
tdm::store_2d_to_global(&out_gate_colwise_sh[0], output_gate_colwise,
chunk_it_offset_x, chunk_it_offset_y,
SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz);
}
}
tdm::wait_tensorcnt_0();
__syncthreads();
}
#else
if constexpr (USE_ROWWISE_SCALING) {
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
Expand All @@ -361,7 +412,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
}
}

if constexpr (USE_COLWISE_SCALING) {
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
Expand All @@ -371,6 +422,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
}
__syncthreads();
#endif
}
}
} // namespace gated_kernels
Expand Down
41 changes: 38 additions & 3 deletions transformer_engine/common/util/rocm_cast_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "math.h"
#include "ptx.cuh"
#include "rocm_vectorized_2d.cuh"
#include "tdm.cuh"
#include "transformer_engine/cast.h"
#include "../transpose/cast_transpose.h"
#include "vectorized_pointwise.h"
Expand Down Expand Up @@ -161,15 +162,31 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
const size_t row_base = chunk_it_offset_y;
#if defined(__gfx1250__)
constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8);
if constexpr (IS_DACT) {
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&act_in_sh[0][0], act_input_ptr,
chunk_it_offset_x, chunk_it_offset_y, cols,
tdm::copy_2d_to_shared_x2(
&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y,
&act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz);
} else {
tdm::copy_2d_to_shared(
&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz);
}
tdm::wait_tensorcnt_0();
__syncthreads();
#else
if constexpr (IS_DACT) {
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&act_in_sh[0][0], act_input_ptr,
chunk_it_offset_x, chunk_it_offset_y, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols);
}
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, rows, cols);
__syncthreads();
#endif

if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in;
Expand Down Expand Up @@ -312,6 +329,23 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)

__syncthreads();

#if defined(__gfx1250__)
constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8);
if constexpr (USE_ROWWISE_SCALING) {
tdm::store_2d_to_global(&out_rowwise_sh[0][0], output_rowwise,
chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y,
cols, rows, cols, out_data_sz);
}
if constexpr (USE_COLWISE_SCALING) {
tdm::store_2d_to_global(&out_colwise_sh[0][0], output_colwise,
chunk_it_offset_x, chunk_it_offset_y,
MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y,
cols, rows, cols, out_data_sz);
}
tdm::wait_tensorcnt_0();
__syncthreads();
#else
if constexpr (USE_ROWWISE_SCALING) {
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x,
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
Expand All @@ -324,6 +358,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
}

__syncthreads();
#endif
}
}

Expand Down
Loading