Skip to content
Open
185 changes: 184 additions & 1 deletion tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
************************************************************************/
#include <cmath>
#include <iostream>
#include <optional>
#include <set>
#include <string>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -33,6 +35,98 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{768, 3072, 4096},
};

// ============================================================================
// Production LLM shapes for MXFP8 GEMM testing.
//
// Each shape is tested with 3 micro-batch sizes (MBS = 1, 2, 4)
// yielding tokens = 4096, 8192, 16384, and 3 layouts (TN, NN, NT)
// via ::testing::Combine.
//
// GemmPass selects the FP8 type combination:
// FWD: E4M3 x E4M3 -> BF16
// DGRAD: E5M2 x E4M3 -> BF16
// WGRAD: E4M3 x E5M2 -> BF16
// ============================================================================

// Identifies which GEMM of a training step a production shape stands for.
// The pass decides how resolve_mkn() maps a shape onto (M, K, N) and which
// FP8 operand types the production test instantiates.
enum class GemmPass {
  FWD,    // operands: E4M3 x E4M3 -> BF16
  DGRAD,  // operands: E5M2 x E4M3 -> BF16
  WGRAD,  // operands: E4M3 x E5M2 -> BF16
};

// Describes one production GEMM shape. Only the two model-defined dimensions
// are stored; the token dimension is derived from the micro-batch size in
// resolve_mkn().
struct ShapeDef {
  // Identifier for the shape; streamed by operator<< and used to build the
  // generated test name.
  const char* label;
  // N for FWD/DGRAD shapes, M for WGRAD shapes.
  size_t dim1;
  // K for FWD/DGRAD shapes, N for WGRAD shapes.
  size_t dim2;
  // Selects the dimension mapping described above.
  GemmPass pass;
};

// Streams a ShapeDef by writing only its label.
std::ostream& operator<<(std::ostream& os, const ShapeDef& s) {
  os << s.label;
  return os;
}

// Expands a ShapeDef and a micro-batch size into concrete GEMM dimensions.
//
// The token count is mbs * 4096. For FWD and DGRAD it becomes M, with
// N = dim1 and K = dim2. For WGRAD it becomes K, with M = dim1 and N = dim2.
// Results are written through the m, k and n out-parameters; they are left
// untouched when s.pass matches none of the enumerators.
static void resolve_mkn(const ShapeDef& s, size_t mbs,
                        size_t& m, size_t& k, size_t& n) {
  const size_t token_count = mbs * 4096;
  switch (s.pass) {
    case GemmPass::WGRAD:
      m = s.dim1;
      n = s.dim2;
      k = token_count;
      return;
    case GemmPass::FWD:
    case GemmPass::DGRAD:
      m = token_count;
      n = s.dim1;
      k = s.dim2;
      return;
  }
}

// DeepSeek3 (hidden=7168, MLA, seq=4096, incl. LM Head)
// Each entry is {label, dim1, dim2, pass}. resolve_mkn() maps dim1/dim2 onto
// the GEMM dimensions according to the pass and derives the token dimension
// from the micro-batch size.
static const ShapeDef deepseek3_shapes[] = {
// Forward: M = tokens, dim1 = N, dim2 = K
{"DeepSeek3_Linear0_fwd", 1536, 7168, GemmPass::FWD},
{"DeepSeek3_Linear1_fwd", 576, 7168, GemmPass::FWD},
{"DeepSeek3_LNLinear0_fwd", 24576, 1536, GemmPass::FWD},
{"DeepSeek3_LNLinear1_fwd", 32768, 512, GemmPass::FWD},
{"DeepSeek3_Linear_attn_fwd", 7168, 16384, GemmPass::FWD},
{"DeepSeek3_LNMLP_gateup_fwd", 36864, 7168, GemmPass::FWD},
{"DeepSeek3_LNMLP_down_fwd", 7168, 18432, GemmPass::FWD},
{"DeepSeek3_SharedExp_gu_fwd", 4096, 7168, GemmPass::FWD},
{"DeepSeek3_SharedExp_dn_fwd", 7168, 2048, GemmPass::FWD},
{"DeepSeek3_TopKRouter_fwd", 256, 7168, GemmPass::FWD},
{"DeepSeek3_LMHead_fwd", 129280, 7168, GemmPass::FWD},
// Dgrad: M = tokens, dim1 = N, dim2 = K
{"DeepSeek3_attn_dgrad", 16384, 7168, GemmPass::DGRAD},
{"DeepSeek3_LNLinear1_dgrad", 512, 32768, GemmPass::DGRAD},
{"DeepSeek3_LNLinear0_dgrad", 1536, 24576, GemmPass::DGRAD},
{"DeepSeek3_SharedExp_dn_dgrad", 2048, 7168, GemmPass::DGRAD},
{"DeepSeek3_SharedExp_gu_dgrad", 7168, 4096, GemmPass::DGRAD},
{"DeepSeek3_TopKRouter_dgrad", 7168, 256, GemmPass::DGRAD},
{"DeepSeek3_MLP_post_dgrad", 7168, 14336, GemmPass::DGRAD},
{"DeepSeek3_LMHead_dgrad", 7168, 129280, GemmPass::DGRAD},
// Wgrad: dim1 = M, dim2 = N, K = tokens
{"DeepSeek3_attn_wgrad", 16384, 7168, GemmPass::WGRAD},
{"DeepSeek3_LNLinear1_wgrad", 512, 32768, GemmPass::WGRAD},
{"DeepSeek3_LNLinear0_wgrad", 1536, 24576, GemmPass::WGRAD},
{"DeepSeek3_SharedExp_dn_wgrad", 2048, 7168, GemmPass::WGRAD},
{"DeepSeek3_SharedExp_gu_wgrad", 7168, 4096, GemmPass::WGRAD},
{"DeepSeek3_TopKRouter_wgrad", 7168, 256, GemmPass::WGRAD},
{"DeepSeek3_LMHead_wgrad", 7168, 129280, GemmPass::WGRAD},
};

// Qwen3 (hidden=4096, GQA, seq=4096, incl. LM Head)
// Each entry is {label, dim1, dim2, pass}. resolve_mkn() maps dim1/dim2 onto
// the GEMM dimensions according to the pass and derives the token dimension
// from the micro-batch size.
static const ShapeDef qwen3_shapes[] = {
// Forward: M = tokens, dim1 = N, dim2 = K
{"Qwen3_LNLinear_QKV_fwd", 9216, 4096, GemmPass::FWD},
{"Qwen3_Linear_attn_fwd", 4096, 8192, GemmPass::FWD},
{"Qwen3_Router_fwd", 128, 4096, GemmPass::FWD},
{"Qwen3_LMHead_fwd", 151936, 4096, GemmPass::FWD},
// Dgrad: M = tokens, dim1 = N, dim2 = K
{"Qwen3_Router_dgrad", 4096, 128, GemmPass::DGRAD},
{"Qwen3_Linear_attn_dgrad", 8192, 4096, GemmPass::DGRAD},
{"Qwen3_LNLinear_dgrad", 4096, 9216, GemmPass::DGRAD},
{"Qwen3_LMHead_dgrad", 4096, 151936, GemmPass::DGRAD},
// Wgrad: dim1 = M, dim2 = N, K = tokens
{"Qwen3_Router_wgrad", 4096, 128, GemmPass::WGRAD},
{"Qwen3_Linear_attn_wgrad", 8192, 4096, GemmPass::WGRAD},
{"Qwen3_LNLinear_wgrad", 4096, 9216, GemmPass::WGRAD},
{"Qwen3_LMHead_wgrad", 4096, 151936, GemmPass::WGRAD},
};

// A, B, Bias, Gelu, D
// Bias type is chosen as bf16 in the use_fp8 case, D_type otherwise
// Gelu type the same as Bias_Type
Expand Down Expand Up @@ -559,7 +653,9 @@ void performTest(const TestParams& params) {

#ifdef __HIP_PLATFORM_AMD__
template <typename A_Type, typename B_Type, typename D_Type>
void performDqTest(const TestParams &params) {
void performDqTest(const TestParams &params,
std::optional<double> atol_override = std::nullopt,
std::optional<double> rtol_override = std::nullopt) {
DType atype = TypeInfo<A_Type>::dtype;
DType btype = TypeInfo<B_Type>::dtype;
DType dtype = TypeInfo<D_Type>::dtype;
Expand Down Expand Up @@ -633,6 +729,10 @@ void performDqTest(const TestParams &params) {

//compare results
auto [atol, rtol] = getTestTolerances(dtype, true, true);
if (atol_override)
atol = *atol_override;
if (rtol_override)
rtol = *rtol_override;
compareResults("D", D, D_ref.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
}
#endif // __HIP_PLATFORM_AMD__
Expand Down Expand Up @@ -751,6 +851,89 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
});

// ============================================================================
// Production GEMM shape instantiations (run with --gtest_filter='ProdGemm*')
// ============================================================================

// Known-failing GEMM shapes on gfx950.
// Keys have the form "<ShapeDef label>_mbs<micro-batch size>_<layout>", which
// is the string the TestMxfp8Dq test builds before looking itself up here.
// NOTE(review): the lookup in TestMxfp8Dq is by name only; no device
// architecture check is visible there, so these cases appear to be skipped on
// every target -- confirm whether the skip should be gated on gfx950.
static const std::set<std::string> kGfx950Skips = {
"DeepSeek3_Linear1_fwd_mbs1_NT",
"DeepSeek3_Linear1_fwd_mbs2_NT",
"DeepSeek3_Linear1_fwd_mbs4_NT",
"DeepSeek3_LNLinear0_fwd_mbs4_NN",
"DeepSeek3_LNLinear0_fwd_mbs4_NT",
"DeepSeek3_attn_wgrad_mbs1_NN",
"Qwen3_LMHead_fwd_mbs2_NN",
"Qwen3_Router_fwd_mbs2_NT",
"Qwen3_LMHead_fwd_mbs4_TN",
"Qwen3_LMHead_fwd_mbs4_NN",
"Qwen3_LMHead_fwd_mbs4_NT",
};

// Production GEMM test suite using ShapeDef x MBS x Layout via testing::Combine.
// Tuple elements, in order: the shape definition, the micro-batch size, and
// the transpose layout.
using ProdGemmParam = std::tuple<ShapeDef, size_t, Layout>;

// Value-parameterized fixture for the production shapes; it adds no state or
// set-up of its own beyond carrying a ProdGemmParam.
class ProdDqGEMMTestSuite : public ::testing::TestWithParam<ProdGemmParam> {};

// Runs a single production-shape MXFP8 GEMM through performDqTest.
//
// The parameter tuple supplies the shape, the micro-batch size and the
// transpose layout. Cases whose generated name is listed in kGfx950Skips are
// skipped; all others are dispatched with FP8 operand types chosen by the
// shape's GemmPass.
TEST_P(ProdDqGEMMTestSuite, TestMxfp8Dq) {
  const auto& [shape, mbs, layout] = GetParam();

  // Build "<label>_mbs<N>_<layout>", the key format used by kGfx950Skips.
  std::string name(shape.label);
  name += "_mbs";
  name += std::to_string(mbs);
  name += "_";
  name += TN(layout);
  if (kGfx950Skips.find(name) != kGfx950Skips.end()) {
    GTEST_SKIP() << "Known gfx950 hipBLASLt failure: " << name;
  }

  // Concrete GEMM dimensions for this shape and micro-batch size.
  size_t m, k, n;
  resolve_mkn(shape, mbs, m, k, n);

  TestParams params = {.m = m, .k = k, .n = n,
                       .use_bias = false, .use_gelu = false,
                       .transa = layout.first, .transb = layout.second,
                       .scaling_mode = NVTEScalingMode::NVTE_MXFP8_1D_SCALING};

  // Tolerance overrides for production shapes. Per the author's rationale,
  // the MXFP8 GEMM and the bf16 reference GEMM accumulate along different
  // internal paths, so their bf16 outputs can differ slightly.
  // NOTE(review): 2e-2 is wider than one bf16 ULP (bf16 epsilon is 2^-7,
  // about 0.8%) -- confirm this is the intended bound.
  constexpr double kProdAtol = 1e-3;
  constexpr double kProdRtol = 2e-2;

  // Template arguments are <A type, B type, D type>.
  switch (shape.pass) {
    case GemmPass::FWD:
      performDqTest<fp8, fp8, bf16>(params, kProdAtol, kProdRtol);
      break;
    case GemmPass::DGRAD:
      performDqTest<bf8, fp8, bf16>(params, kProdAtol, kProdRtol);
      break;
    case GemmPass::WGRAD:
      performDqTest<fp8, bf8, bf16>(params, kProdAtol, kProdRtol);
      break;
  }
}

// Name generator for the production GEMM instantiations: produces
// "<label>_mbs<micro-batch size>_<layout>" from the parameter tuple.
static auto prodTestName = [](const testing::TestParamInfo<ProdGemmParam>& info) {
  const auto& [shape, mbs, layout] = info.param;
  std::string name(shape.label);
  name += "_mbs";
  name += std::to_string(mbs);
  name += "_";
  name += TN(layout);
  return name;
};

// Instantiates ProdDqGEMMTestSuite for every combination of a DeepSeek3
// shape, a micro-batch size of 1, 2 or 4, and an entry of kLayouts. Test
// names are generated by prodTestName.
INSTANTIATE_TEST_SUITE_P(ProdGemmDeepSeek3, ProdDqGEMMTestSuite,
::testing::Combine(
::testing::ValuesIn(deepseek3_shapes),
::testing::Values(size_t{1}, size_t{2}, size_t{4}),
::testing::ValuesIn(kLayouts)),
prodTestName);

// Instantiates ProdDqGEMMTestSuite for every combination of a Qwen3 shape, a
// micro-batch size of 1, 2 or 4, and an entry of kLayouts. Test names are
// generated by prodTestName.
INSTANTIATE_TEST_SUITE_P(ProdGemmQwen3, ProdDqGEMMTestSuite,
::testing::Combine(
::testing::ValuesIn(qwen3_shapes),
::testing::Values(size_t{1}, size_t{2}, size_t{4}),
::testing::ValuesIn(kLayouts)),
prodTestName);

TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
const size_t rows = 128;
const size_t cols = 256;
Expand Down
Loading