-
Notifications
You must be signed in to change notification settings - Fork 29
ck_tile grouped gemm: more padding options #574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 10 commits
95f984c
cfbc537
225c3dc
2939017
fa87ccc
01f62d0
aee2c4c
f830b89
2751b2a
a59a4ae
d68040b
a30e591
dff5635
fc4a101
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3075,6 +3075,144 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): | |
| os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) | ||
|
|
||
|
|
||
| if IS_HIP_EXTENSION: | ||
| @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str) | ||
| @pytest.mark.parametrize("layout", ["TN", "NN", "NT", "TT"]) | ||
| @pytest.mark.parametrize("accumulate", [False, True]) | ||
| @pytest.mark.parametrize( | ||
| "pad_dim", | ||
| ["K", "M", "N", "MK", "MKN"], | ||
| ids=lambda d: f"pad{d}", | ||
| ) | ||
| def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim, capfd): | ||
| """Test CK grouped GEMM with M, N, or K not aligned to CK tile size. | ||
|
|
||
| CK constraints for bf16/fp16: | ||
| - Contiguous dim of A/B must be dword-aligned (even for 2-byte types). | ||
| RowMajor: contiguous dim is cols (K for A, N for B). | ||
| ColMajor: contiguous dim is rows (M for A, K for B). | ||
| - K tile: 64, M tile: 256, N tile: 128/256 | ||
| """ | ||
| torch.manual_seed(0) | ||
| z = 8 | ||
|
|
||
| # Unaligned values per dimension (all satisfy CK vector-load constraints). | ||
| # K: even but not multiple of tile (64). Same for all groups. | ||
| # M: not multiples of tile (256), varies per group. | ||
| # N: multiple of 16 but not multiple of tile (128). | ||
| unaligned_k = 2016 | ||
| unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180] | ||
| unaligned_n = 2032 | ||
|
|
||
| # Aligned defaults. | ||
| k_aligned = 2048 | ||
| m_aligned = 256 | ||
| n_aligned = 2048 | ||
|
|
||
| # Select (un)aligned values based on pad_dim. | ||
| k_val = unaligned_k if "K" in pad_dim else k_aligned | ||
| m_vals = unaligned_m if "M" in pad_dim else [m_aligned] * z | ||
| n_val = unaligned_n if "N" in pad_dim else n_aligned | ||
|
|
||
| total_m = sum(m_vals) | ||
| os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: better use monkeypath to make sure the envs are cleared if tests fails
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in dff5635 |
||
| os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" | ||
|
|
||
| if layout == "TN": | ||
| A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)] | ||
| B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals] | ||
| out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")] | ||
| out_ref = [o.clone() for o in torch.split(out[0], m_vals)] | ||
| m_splits = m_vals | ||
| grad = False | ||
| single_output = True | ||
| elif layout == "NN": | ||
| A = [torch.randn(k_val, n_val, dtype=dtype, device="cuda") for _ in range(z)] | ||
| B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals] | ||
| out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")] | ||
| out_ref = [o.clone() for o in torch.split(out[0], m_vals)] | ||
| m_splits = m_vals | ||
| grad = True | ||
| single_output = True | ||
| elif layout == "NT": | ||
| A = list(torch.split( | ||
| torch.randn(total_m, k_val, dtype=dtype, device="cuda"), m_vals | ||
| )) | ||
| B = list(torch.split( | ||
| torch.randn(total_m, n_val, dtype=dtype, device="cuda"), m_vals | ||
| )) | ||
| out = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)] | ||
| out_ref = [o.clone() for o in out] | ||
| m_splits = m_vals | ||
| grad = True | ||
| single_output = False | ||
| else: # TT | ||
| A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)] | ||
| B = [torch.randn(k_val, m, dtype=dtype, device="cuda") for m in m_vals] | ||
| out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")] | ||
| out_ref = [o.clone() for o in torch.split(out[0], m_vals)] | ||
| m_splits = m_vals | ||
| grad = False | ||
| single_output = True | ||
|
|
||
| # Reference: individual GEMMs | ||
| for i in range(z): | ||
| if layout == "TT": | ||
| # general_gemm doesn't support TT; compute reference manually. | ||
| ref = B[i].T.to(torch.float32) @ A[i].T.to(torch.float32) | ||
| if accumulate: | ||
| out_ref[i] = (out_ref[i].to(torch.float32) + ref).to(dtype) | ||
| else: | ||
| out_ref[i] = ref.to(dtype) | ||
| else: | ||
| general_gemm( | ||
| A[i], | ||
| B[i], | ||
| dtype, | ||
| grad=grad, | ||
| accumulate=accumulate, | ||
| layout=layout, | ||
| out=out_ref[i], | ||
| ) | ||
|
|
||
| if single_output: | ||
| out_ref = [torch.cat(out_ref)] | ||
|
|
||
| general_grouped_gemm( | ||
| A, | ||
| B, | ||
| out, | ||
| [None] * z, | ||
| dtype, | ||
| m_splits=m_splits, | ||
| grad=grad, | ||
| accumulate=accumulate, | ||
| layout=layout, | ||
| single_output=single_output, | ||
| ) | ||
|
|
||
| for o, o_ref in zip(out, out_ref): | ||
| if IS_HIP_EXTENSION and accumulate and dtype == torch.bfloat16 and get_device_compute_capability() == (9, 4): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test itself is IS_HIP_EXTENSION only
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, removed the extra if in dff5635. |
||
| torch.testing.assert_close(o, o_ref, rtol=4e-2, atol=4e-2) | ||
| else: | ||
| torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) | ||
|
|
||
| os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) | ||
| os.environ.pop("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", None) | ||
|
|
||
| # Check for CK fallback warnings from C++ (NVTE_WARN writes to std::cerr). | ||
| # capfd captures file-descriptor-level output, including C/C++ stderr. | ||
| captured = capfd.readouterr() | ||
| if "Falling back" in captured.err or "Fallback" in captured.err: | ||
| if "K" in pad_dim and layout != "NN": | ||
| pytest.xfail( | ||
| "Known CK_Tile limitation: K-padding with non-NN layouts may fall back to cuBLAS " | ||
| "(kPadK + ColMajor B bug, or CK_Tile stride alignment requirements)" | ||
| ) | ||
| else: | ||
| pytest.fail(f"CK_Tile grouped GEMM fell back to cuBLAS:\n{captured.err}") | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("N", [32]) | ||
| @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) | ||
| @pytest.mark.parametrize( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think z should be derived as len of unaligned_m, or it should be asserted that they are equal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in dff5635