diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 85f183bf7..669238baf 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -5,6 +5,8 @@ ************************************************************************/ #include #include +#include +#include #include #include #include @@ -33,6 +35,98 @@ std::vector> test_case_sizes_mxfp8 = { {768, 3072, 4096}, }; +// ============================================================================ +// Production LLM shapes for MXFP8 GEMM testing. +// +// Each shape is tested with 3 micro-batch sizes (MBS = 1, 2, 4) +// yielding tokens = 4096, 8192, 16384, and 3 layouts (TN, NN, NT) +// via ::testing::Combine. +// +// GemmPass selects the FP8 type combination: +// FWD: E4M3 x E4M3 -> BF16 +// DGRAD: E5M2 x E4M3 -> BF16 +// WGRAD: E4M3 x E5M2 -> BF16 +// ============================================================================ + +enum class GemmPass { FWD, DGRAD, WGRAD }; + +struct ShapeDef { + const char* label; + size_t dim1; // FWD/DGRAD: N, WGRAD: M + size_t dim2; // FWD/DGRAD: K, WGRAD: N + GemmPass pass; +}; + +std::ostream& operator<<(std::ostream& os, const ShapeDef& s) { + return os << s.label; +} + +static void resolve_mkn(const ShapeDef& s, size_t mbs, + size_t& m, size_t& k, size_t& n) { + size_t tokens = mbs * 4096; + switch (s.pass) { + case GemmPass::FWD: + case GemmPass::DGRAD: + m = tokens; n = s.dim1; k = s.dim2; + break; + case GemmPass::WGRAD: + m = s.dim1; n = s.dim2; k = tokens; + break; + } +} + +// DeepSeek3 (hidden=7168, MLA, seq=4096, incl. LM Head) +static const ShapeDef deepseek3_shapes[] = { + // Forward (M=tokens, N, K) + {"DeepSeek3_Linear0_fwd", 1536, 7168, GemmPass::FWD}, + {"DeepSeek3_Linear1_fwd", 576, 7168, GemmPass::FWD}, + {"DeepSeek3_LNLinear0_fwd", 24576, 1536, GemmPass::FWD}, + {"DeepSeek3_LNLinear1_fwd", 32768, 512, GemmPass::FWD}, + {"DeepSeek3_Linear_attn_fwd", 7168, 16384, GemmPass::FWD}, + {"DeepSeek3_LNMLP_gateup_fwd", 36864, 7168, GemmPass::FWD}, + {"DeepSeek3_LNMLP_down_fwd", 7168, 18432, GemmPass::FWD}, + {"DeepSeek3_SharedExp_gu_fwd", 4096, 7168, GemmPass::FWD}, + {"DeepSeek3_SharedExp_dn_fwd", 7168, 2048, GemmPass::FWD}, + {"DeepSeek3_TopKRouter_fwd", 256, 7168, GemmPass::FWD}, + {"DeepSeek3_LMHead_fwd", 129280, 7168, GemmPass::FWD}, + // Dgrad (M=tokens, N, K) + {"DeepSeek3_attn_dgrad", 16384, 7168, GemmPass::DGRAD}, + {"DeepSeek3_LNLinear1_dgrad", 512, 32768, GemmPass::DGRAD}, + {"DeepSeek3_LNLinear0_dgrad", 1536, 24576, GemmPass::DGRAD}, + {"DeepSeek3_SharedExp_dn_dgrad", 2048, 7168, GemmPass::DGRAD}, + {"DeepSeek3_SharedExp_gu_dgrad", 7168, 4096, GemmPass::DGRAD}, + {"DeepSeek3_TopKRouter_dgrad", 7168, 256, GemmPass::DGRAD}, + {"DeepSeek3_MLP_post_dgrad", 7168, 14336, GemmPass::DGRAD}, + {"DeepSeek3_LMHead_dgrad", 7168, 129280, GemmPass::DGRAD}, + // Wgrad (M, N, K=tokens) + {"DeepSeek3_attn_wgrad", 16384, 7168, GemmPass::WGRAD}, + {"DeepSeek3_LNLinear1_wgrad", 512, 32768, GemmPass::WGRAD}, + {"DeepSeek3_LNLinear0_wgrad", 1536, 24576, GemmPass::WGRAD}, + {"DeepSeek3_SharedExp_dn_wgrad", 2048, 7168, GemmPass::WGRAD}, + {"DeepSeek3_SharedExp_gu_wgrad", 7168, 4096, GemmPass::WGRAD}, + {"DeepSeek3_TopKRouter_wgrad", 7168, 256, GemmPass::WGRAD}, + {"DeepSeek3_LMHead_wgrad", 7168, 129280, GemmPass::WGRAD}, +}; + +// Qwen3 (hidden=4096, GQA, seq=4096, incl. LM Head) +static const ShapeDef qwen3_shapes[] = { + // Forward (M=tokens, N, K) + {"Qwen3_LNLinear_QKV_fwd", 9216, 4096, GemmPass::FWD}, + {"Qwen3_Linear_attn_fwd", 4096, 8192, GemmPass::FWD}, + {"Qwen3_Router_fwd", 128, 4096, GemmPass::FWD}, + {"Qwen3_LMHead_fwd", 151936, 4096, GemmPass::FWD}, + // Dgrad (M=tokens, N, K) + {"Qwen3_Router_dgrad", 4096, 128, GemmPass::DGRAD}, + {"Qwen3_Linear_attn_dgrad", 8192, 4096, GemmPass::DGRAD}, + {"Qwen3_LNLinear_dgrad", 4096, 9216, GemmPass::DGRAD}, + {"Qwen3_LMHead_dgrad", 4096, 151936, GemmPass::DGRAD}, + // Wgrad (M, N, K=tokens) + {"Qwen3_Router_wgrad", 4096, 128, GemmPass::WGRAD}, + {"Qwen3_Linear_attn_wgrad", 8192, 4096, GemmPass::WGRAD}, + {"Qwen3_LNLinear_wgrad", 4096, 9216, GemmPass::WGRAD}, + {"Qwen3_LMHead_wgrad", 4096, 151936, GemmPass::WGRAD}, +}; + // A, B, Bias, Gelu, D // Bias type choose as bf16 in use_fp8, D_type otherwise // Gelu type the same as Bias_Type @@ -559,7 +653,9 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ template -void performDqTest(const TestParams ¶ms) { +void performDqTest(const TestParams ¶ms, + std::optional atol_override = std::nullopt, + std::optional rtol_override = std::nullopt) { DType atype = TypeInfo::dtype; DType btype = TypeInfo::dtype; DType dtype = TypeInfo::dtype; @@ -633,6 +729,10 @@ void performDqTest(const TestParams ¶ms) { //compare results auto [atol, rtol] = getTestTolerances(dtype, true, true); + if (atol_override) + atol = *atol_override; + if (rtol_override) + rtol = *rtol_override; compareResults("D", D, D_ref.rowwise_cpu_dptr(), true, atol, rtol); } #endif // __HIP_PLATFORM_AMD__ @@ -751,6 +851,89 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite, return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param)); }); +// ============================================================================ +// Production GEMM shape instantiations (run with --gtest_filter='ProdGemm*') +// ============================================================================ + +// Known-failing GEMM shapes on gfx950 +static const std::set kGfx950Skips = { + "DeepSeek3_Linear1_fwd_mbs1_NT", + "DeepSeek3_Linear1_fwd_mbs2_NT", + "DeepSeek3_Linear1_fwd_mbs4_NT", + "DeepSeek3_LNLinear0_fwd_mbs4_NN", + "DeepSeek3_LNLinear0_fwd_mbs4_NT", + "DeepSeek3_attn_wgrad_mbs1_NN", + "Qwen3_LMHead_fwd_mbs2_NN", + "Qwen3_Router_fwd_mbs2_NT", + "Qwen3_LMHead_fwd_mbs4_TN", + "Qwen3_LMHead_fwd_mbs4_NN", + "Qwen3_LMHead_fwd_mbs4_NT", +}; + +// Production GEMM test suite using ShapeDef x MBS x Layout via testing::Combine. +using ProdGemmParam = std::tuple; + +class ProdDqGEMMTestSuite : public ::testing::TestWithParam {}; + +TEST_P(ProdDqGEMMTestSuite, TestMxfp8Dq) { + const auto& shape = std::get<0>(GetParam()); + size_t mbs = std::get<1>(GetParam()); + const auto& layout = std::get<2>(GetParam()); + + std::string name = std::string(shape.label) + "_mbs" + std::to_string(mbs) + + "_" + TN(layout); + if (kGfx950Skips.count(name)) { + GTEST_SKIP() << "Known gfx950 hipBLASLt failure: " << name; + } + + size_t m, k, n; + resolve_mkn(shape, mbs, m, k, n); + + TestParams params = {.m = m, .k = k, .n = n, + .use_bias = false, .use_gelu = false, + .transa = layout.first, .transb = layout.second, + .scaling_mode = NVTEScalingMode::NVTE_MXFP8_1D_SCALING}; + + // Production shapes use looser tolerances: the MXFP8 and bf16 reference + // GEMM use different internal accumulation paths, so results can differ + // by up to 1 ULP in bf16 (~1.5-2% relative). + const double prod_atol = 1e-3; + const double prod_rtol = 2e-2; + + switch (shape.pass) { + case GemmPass::FWD: + performDqTest(params, prod_atol, prod_rtol); + break; + case GemmPass::DGRAD: + performDqTest(params, prod_atol, prod_rtol); + break; + case GemmPass::WGRAD: + performDqTest(params, prod_atol, prod_rtol); + break; + } +} + +static auto prodTestName = [](const testing::TestParamInfo& info) { + const auto& shape = std::get<0>(info.param); + size_t mbs = std::get<1>(info.param); + const auto& layout = std::get<2>(info.param); + return std::string(shape.label) + "_mbs" + std::to_string(mbs) + "_" + TN(layout); +}; + +INSTANTIATE_TEST_SUITE_P(ProdGemmDeepSeek3, ProdDqGEMMTestSuite, + ::testing::Combine( + ::testing::ValuesIn(deepseek3_shapes), + ::testing::Values(size_t{1}, size_t{2}, size_t{4}), + ::testing::ValuesIn(kLayouts)), + prodTestName); + +INSTANTIATE_TEST_SUITE_P(ProdGemmQwen3, ProdDqGEMMTestSuite, + ::testing::Combine( + ::testing::ValuesIn(qwen3_shapes), + ::testing::Values(size_t{1}, size_t{2}, size_t{4}), + ::testing::ValuesIn(kLayouts)), + prodTestName); + TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) { const size_t rows = 128; const size_t cols = 256;