Skip to content
Open
122 changes: 122 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3078,6 +3078,128 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


@pytest.mark.skipif(
    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
    reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN"])
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize(
    "pad_dim",
    ["K", "M", "N"],
    ids=lambda d: f"pad{d}",
)
def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim):
    """Test CK grouped GEMM with M, N, or K not aligned to CK tile size.

    Exactly one of the GEMM dimensions (selected by ``pad_dim``) is set to a
    value that is not a multiple of the tile size; the other two stay aligned.
    The grouped result is compared against per-group ``general_gemm`` calls.

    CK constraints for bf16/fp16:
    - Contiguous dim of A/B must be dword-aligned (even for 2-byte types).
      RowMajor: contiguous dim is cols (K for A, N for B).
      ColMajor: contiguous dim is rows (M for A, K for B).
    - N: must be multiple of 16 (GetVectorSizeC, no dword fallback), tile 128/256
    - K tile: 64, M tile: 256

    The ``NVTE_USE_CUTLASS_GROUPED_GEMM`` environment variable is set for the
    duration of the test and removed again even if the test fails, so a
    failure here cannot change the backend used by later tests.
    """
    torch.manual_seed(0)
    z = 8

    # Unaligned values per dimension (all satisfy CK vector-load constraints).
    # K: even but not multiple of tile (64). Same for all groups.
    # M: not multiples of tile (256), varies per group.
    # N: multiple of 16 but not multiple of tile (128).
    unaligned_k = 2026
    unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
    unaligned_n = 2032

    # Aligned defaults.
    k_aligned = 2048
    m_aligned = 256
    n_aligned = 2048

    # Start from the aligned problem and break alignment on one dimension only.
    # The selection is identical for both layouts; only the shape of A differs.
    k_val = k_aligned
    m_vals = [m_aligned] * z
    n_val = n_aligned
    if pad_dim == "K":
        k_val = unaligned_k
    elif pad_dim == "M":
        m_vals = unaligned_m
    else:  # N
        n_val = unaligned_n

    if layout == "TN":
        # TN GEMM: M=m_splits[i], N=A.rows, K=A.cols
        a_shape = (n_val, k_val)
        grad = False
    else:  # NN
        # NN GEMM: M=m_splits[i], N=A.cols, K=A.rows
        a_shape = (k_val, n_val)
        grad = True

    m_splits = m_vals
    single_output = True

    os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
    try:
        # Tensor creation order (A, then B, then out) is kept fixed so the
        # seeded random stream yields the same inputs for a given parametrization.
        A = [torch.randn(*a_shape, dtype=dtype, device="cuda") for _ in range(z)]
        B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
        total_m = sum(m_vals)
        out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
        # Clone the per-group slices so the reference path starts from the same
        # initial output values (relevant when accumulate=True).
        out_ref = [o.clone() for o in torch.split(out[0], m_vals)]

        # Reference: individual GEMMs
        for a, b, o_ref in zip(A, B, out_ref):
            general_gemm(
                a,
                b,
                dtype,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                out=o_ref,
            )
        if single_output:
            out_ref = [torch.cat(out_ref)]

        general_grouped_gemm(
            A,
            B,
            out,
            [None] * z,
            dtype,
            m_splits=m_splits,
            grad=grad,
            accumulate=accumulate,
            layout=layout,
            single_output=single_output,
        )

        # Looser tolerance for accumulating bf16 on ROCm devices reporting
        # compute capability (9, 4); computed once since it does not depend on
        # the tensor being compared.
        relaxed = (
            IS_HIP_EXTENSION
            and accumulate
            and dtype == torch.bfloat16
            and get_device_compute_capability() == (9, 4)
        )
        tol = 4e-2 if relaxed else 1.5e-2

        for o, o_ref in zip(out, out_ref):
            torch.testing.assert_close(o, o_ref, rtol=tol, atol=tol)
    finally:
        # Always undo the backend override, including on assertion failure.
        os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)

@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,17 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,

if (!Kernel::IsSupportedArgument(kargs)) {
NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
"Falling back.");
"transA=", ctx.transA, " transB=", ctx.transB,
" accumulate=", ctx.accumulate, " groups=", ctx.group_num,
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even), "
"N must be multiple of 16 (GetVectorSizeC).");
Comment thread
aris134 marked this conversation as resolved.
Outdated
for (size_t i = 0; i < descs.size(); ++i) {
NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K,
" stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B,
" stride_E=", descs[i].stride_E);
}
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ struct TileCfg_256x128x64 : TileCfg_256x256x64 {
static constexpr ck_tile::index_t N_Tile = 128;
};

struct TileCfg_256x128x64_padding : TileCfg_256x128x64 {
static constexpr bool kPadN = true;
// Tile-config adaptor: inherits everything from BaseCfg and defines the three
// padding switches (kPadM / kPadN / kPadK) from the supplied compile-time
// flags, shadowing any same-named members the base may declare.
template <typename BaseCfg, bool EnablePadM, bool EnablePadN, bool EnablePadK>
struct WithPadding : BaseCfg {
  static constexpr bool kPadM = EnablePadM;
  static constexpr bool kPadN = EnablePadN;
  static constexpr bool kPadK = EnablePadK;
};

template <typename AType,
Expand Down Expand Up @@ -196,15 +199,15 @@ class GroupedGemmRunner : public RunnerInterface {
}
};

#define MAKE_RUNNER(TileCfg_) \
#define MAKE_RUNNER(BaseCfg_, kPadM_, kPadN_, kPadK_) \
TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \
using Runner = GroupedGemmRunner<AType, \
BType, \
CType, \
ALayout, \
BLayout, \
CLayout, \
TileCfg_, \
WithPadding<BaseCfg_, kPadM_, kPadN_, kPadK_>, \
accum_option>; \
runner = std::make_unique<Runner>(); \
})
Expand All @@ -216,6 +219,37 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
const ck_tile::stream_config s{ctx.stream};
std::unique_ptr<RunnerInterface> runner = nullptr;

// Check M and K alignment across all groups.
// All tile configs share the same M_Tile (256) and K_Tile (64).
constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile;
constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile;

bool need_m_pad = false;
bool need_k_pad = false;

for (int i = 0; i < ctx.group_num; ++i) {
const transformer_engine::Tensor* A_te =
transformer_engine::convertNVTETensorCheck(ctx.A[i]);
int64_t Ad0 = 0, Ad1 = 0;
if (get_flat_2d_dims(*A_te, Ad0, Ad1)) {
const int64_t M = ctx.transA ? Ad1 : Ad0;
const int64_t K = ctx.transA ? Ad0 : Ad1;

if (M % M_Tile != 0)
need_m_pad = true;
if (K % K_Tile != 0)
need_k_pad = true;
if (need_m_pad && need_k_pad)
break;
}
}

// CK tile kernel produces incorrect results with kPadK + ColMajor B.
// Fall back to cuBLAS for this combination.
if (need_k_pad && ctx.transB) {
return false;
}
Comment thread
aris134 marked this conversation as resolved.

TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, {
using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

Expand All @@ -230,13 +264,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, {
using CType = typename TETypeToCKType<d_te_type>::type;

if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64);
} else {
MAKE_RUNNER(TileCfg_256x128x64_padding);
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, {
TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, {
if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK);
} else {
MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK);
}
});
});
});
});
});
Expand Down
Loading