Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3075,6 +3075,144 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


if IS_HIP_EXTENSION:
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT", "TT"])
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize(
"pad_dim",
["K", "M", "N", "MK", "MKN"],
ids=lambda d: f"pad{d}",
)
def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim, capfd):
"""Test CK grouped GEMM with M, N, or K not aligned to CK tile size.

CK constraints for bf16/fp16:
- Contiguous dim of A/B must be dword-aligned (even for 2-byte types).
RowMajor: contiguous dim is cols (K for A, N for B).
ColMajor: contiguous dim is rows (M for A, K for B).
- K tile: 64, M tile: 256, N tile: 128/256
"""
torch.manual_seed(0)
z = 8

# Unaligned values per dimension (all satisfy CK vector-load constraints).
# K: even but not multiple of tile (64). Same for all groups.
# M: not multiples of tile (256), varies per group.
# N: multiple of 16 but not multiple of tile (128).
unaligned_k = 2016
unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think z should be derived as len of unaligned_m, or it should be asserted that they are equal

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in dff5635

unaligned_n = 2032

# Aligned defaults.
k_aligned = 2048
m_aligned = 256
n_aligned = 2048

# Select (un)aligned values based on pad_dim.
k_val = unaligned_k if "K" in pad_dim else k_aligned
m_vals = unaligned_m if "M" in pad_dim else [m_aligned] * z
n_val = unaligned_n if "N" in pad_dim else n_aligned

total_m = sum(m_vals)
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: better use monkeypath to make sure the envs are cleared if tests fails

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in dff5635

os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1"

if layout == "TN":
A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
m_splits = m_vals
grad = False
single_output = True
elif layout == "NN":
A = [torch.randn(k_val, n_val, dtype=dtype, device="cuda") for _ in range(z)]
B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
m_splits = m_vals
grad = True
single_output = True
elif layout == "NT":
A = list(torch.split(
torch.randn(total_m, k_val, dtype=dtype, device="cuda"), m_vals
))
B = list(torch.split(
torch.randn(total_m, n_val, dtype=dtype, device="cuda"), m_vals
))
out = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
out_ref = [o.clone() for o in out]
m_splits = m_vals
grad = True
single_output = False
else: # TT
A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
B = [torch.randn(k_val, m, dtype=dtype, device="cuda") for m in m_vals]
out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
m_splits = m_vals
grad = False
single_output = True

# Reference: individual GEMMs
for i in range(z):
if layout == "TT":
# general_gemm doesn't support TT; compute reference manually.
ref = B[i].T.to(torch.float32) @ A[i].T.to(torch.float32)
if accumulate:
out_ref[i] = (out_ref[i].to(torch.float32) + ref).to(dtype)
else:
out_ref[i] = ref.to(dtype)
else:
general_gemm(
A[i],
B[i],
dtype,
grad=grad,
accumulate=accumulate,
layout=layout,
out=out_ref[i],
)

if single_output:
out_ref = [torch.cat(out_ref)]

general_grouped_gemm(
A,
B,
out,
[None] * z,
dtype,
m_splits=m_splits,
grad=grad,
accumulate=accumulate,
layout=layout,
single_output=single_output,
)

for o, o_ref in zip(out, out_ref):
if IS_HIP_EXTENSION and accumulate and dtype == torch.bfloat16 and get_device_compute_capability() == (9, 4):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test itself is IS_HIP_EXTENSION only

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, removed the extra if in dff5635.

torch.testing.assert_close(o, o_ref, rtol=4e-2, atol=4e-2)
else:
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)

os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
os.environ.pop("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", None)

# Check for CK fallback warnings from C++ (NVTE_WARN writes to std::cerr).
# capfd captures file-descriptor-level output, including C/C++ stderr.
captured = capfd.readouterr()
if "Falling back" in captured.err or "Fallback" in captured.err:
if "K" in pad_dim and layout != "NN":
pytest.xfail(
"Known CK_Tile limitation: K-padding with non-NN layouts may fall back to cuBLAS "
"(kPadK + ColMajor B bug, or CK_Tile stride alignment requirements)"
)
else:
pytest.fail(f"CK_Tile grouped GEMM fell back to cuBLAS:\n{captured.err}")


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ else()
gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp
amd_detail/system.cpp)
list(APPEND transformer_engine_cuda_sources
fused_attn_rocm/fused_attn_aotriton.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,16 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,

if (!Kernel::IsSupportedArgument(kargs)) {
NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
"Falling back.");
"transA=", ctx.transA, " transB=", ctx.transB,
" accumulate=", ctx.accumulate, " groups=", ctx.group_num,
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even).");
for (size_t i = 0; i < descs.size(); ++i) {
NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K,
" stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B,
" stride_E=", descs[i].stride_E);
}
return false;
}

Expand Down
Loading
Loading