From 4fdcdd594a9656b3d73c0617643b20343dc3f3ff Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 30 Dec 2025 19:01:33 +0530 Subject: [PATCH] Add ROCm/rocWMMA support for RDNA3 (gfx1151) with AMD Windows setup guide --- .gitignore | 87 +++- README_AMD_WINDOWS.md | 185 +++++++ csrc/fused/rocm/dispatch_utils.h | 112 ++++ csrc/fused/rocm/fused.cu | 431 ++++++++++++++++ csrc/fused/rocm/fused.h | 30 ++ csrc/fused/rocm/pybind_rocm.cpp | 33 ++ csrc/qattn/rocm/attn_rocm.h | 87 ++++ csrc/qattn/rocm/dispatch_utils.h | 73 +++ csrc/qattn/rocm/launch_sgattn_f16.cu | 300 +++++++++++ csrc/qattn/rocm/pybind_rocm.cpp | 42 ++ csrc/qattn/rocm/sgattn_f16.cu | 739 +++++++++++++++++++++++++++ csrc/reduction_utils.h | 163 ++++++ csrc/reduction_utils_hip.h | 164 ++++++ csrc/utils.cuh | 38 ++ setup.py | 417 +++++++++------ spas_sage_attn/core.py | 202 ++++++-- 16 files changed, 2912 insertions(+), 191 deletions(-) create mode 100644 README_AMD_WINDOWS.md create mode 100644 csrc/fused/rocm/dispatch_utils.h create mode 100644 csrc/fused/rocm/fused.cu create mode 100644 csrc/fused/rocm/fused.h create mode 100644 csrc/fused/rocm/pybind_rocm.cpp create mode 100644 csrc/qattn/rocm/attn_rocm.h create mode 100644 csrc/qattn/rocm/dispatch_utils.h create mode 100644 csrc/qattn/rocm/launch_sgattn_f16.cu create mode 100644 csrc/qattn/rocm/pybind_rocm.cpp create mode 100644 csrc/qattn/rocm/sgattn_f16.cu create mode 100644 csrc/reduction_utils.h create mode 100644 csrc/reduction_utils_hip.h create mode 100644 csrc/utils.cuh diff --git a/.gitignore b/.gitignore index 2bd6f2b..4069f7f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,80 @@ -__pycache__ -spas_sage_attn.egg-info -*.pkl -/dist -/build +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions *.so -.DS_Store -inst*.cu -/unit_test + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +*.manifest +*.spec + +# pip +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.idea/ +.vscode/ +.cursor/ +*.swp +*.swo + +# ROCm cloned libraries +/third_party/ + +# HIP generated files +*.hip + +# Build artifacts +*.o +*.obj + +# Instantiation generated files +csrc/qattn/instantiations_sm80/*.cu +csrc/qattn/instantiations_sm89/*.cu +csrc/qattn/instantiations_sm90/*.cu diff --git a/README_AMD_WINDOWS.md b/README_AMD_WINDOWS.md new file mode 100644 index 0000000..5d322a6 --- /dev/null +++ b/README_AMD_WINDOWS.md @@ -0,0 +1,185 @@ +# SpargeAttn - AMD ROCm on Windows Setup Guide + +This guide explains how to build and run SpargeAttn on Windows with AMD GPUs using ROCm. + +> **Note:** These steps should also work on Linux with minor modifications (use bash commands instead of PowerShell, `source venv/bin/activate` instead of `.\venv\Scripts\Activate.ps1`, and skip the Visual Studio environment setup). However, Linux support has not been tested yet and may have issues. + +## Supported Hardware + +SpargeAttn on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1103, gfx1151). + +## Prerequisites + +- Windows 10/11 +- Python 3.11, 3.12, or 3.13 +- Visual Studio 2022 with C++ build tools +- AMD Adrenaline driver (latest recommended) + +## Installation + +### 1. Install ROCm and PyTorch from TheRock + +Follow the instructions at [ROCm/TheRock RELEASES.md](https://github.com/ROCm/TheRock/blob/main/RELEASES.md) to install ROCm and PyTorch wheels for your GPU architecture. + +#### Create a Virtual Environment + +```powershell +python -m venv venv +.\venv\Scripts\Activate.ps1 +``` + +#### Install PyTorch (includes ROCm SDK as dependency) + +For **gfx1151** (AMD Strix Halo iGPU): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ --pre torch torchaudio torchvision +``` + +For **gfx110X** (RX 7900 XTX, RX 7800 XT, RX 7700S, Radeon 780M): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/ --pre torch torchaudio torchvision +``` + +For **gfx120X** (RX 9060, RX 9070): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre torch torchaudio torchvision +``` + +#### Initialize ROCm SDK + +```powershell +rocm-sdk init +``` + +#### Install Triton with AMD Windows Support + +```powershell +pip install triton-windows +``` + +### 2. Set Environment Variables + +Open a PowerShell terminal and run: + +```powershell +# Activate Visual Studio environment +cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } } + +# Activate the virtual environment +.\venv\Scripts\Activate.ps1 + +# Set ROCm paths using rocm-sdk +$ROCM_ROOT = (rocm-sdk path --root).Trim() +$ROCM_BIN = (rocm-sdk path --bin).Trim() +$env:ROCM_HOME = $ROCM_ROOT +$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH" + +# Set compiler and build settings +$env:CC = "clang-cl" +$env:CXX = "clang-cl" +$env:DISTUTILS_USE_SDK = "1" + +# Enable experimental features +$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE" +$env:TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL = "1" +``` + +### 3. Build and Install SpargeAttn + +```powershell +cd +pip install --no-build-isolation -v . +``` + +## Testing + +### Quick Correctness Test + +Run this script to verify SpargeAttn is working correctly by comparing against PyTorch SDPA: + +```python +import torch +import torch.nn.functional as F +from spas_sage_attn.core import spas_sage_attn_meansim_cuda + +device = torch.device('cuda') + +# Create random test tensors (use float16 for ROCm compatibility) +q = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device) +k = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device) +v = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device) + +# Compute reference output using PyTorch SDPA +with torch.no_grad(): + sdpa = F.scaled_dot_product_attention(q.float(), k.float(), v.float()).to(torch.float16) + +# Compute SpargeAttn output (with 100% sparsity = dense attention) +sparge = spas_sage_attn_meansim_cuda( + q, k, v, + is_causal=False, + smooth_k=False, + simthreshd1=0.0, # No similarity threshold (keep all blocks) + cdfthreshd=1.0, # 100% sparsity + pvthreshd=0, + tensor_layout='HND' +) + +# Compare outputs using cosine similarity +cos = F.cosine_similarity( + sdpa.flatten().float().unsqueeze(0), + sparge.flatten().float().unsqueeze(0) +) +print(f'Cosine similarity: {cos.item():.6f}') # Should be ~0.9999 +``` + +Save this as `test_spargeattn.py` and run: + +```powershell +python test_spargeattn.py +``` + +Expected output: +``` +Cosine similarity: 0.999900 +``` + +A cosine similarity above 0.999 indicates the kernel is working correctly. + +## Performance Notes + +At L=4096, D=128, bf16 vs PyTorch SDPA (with aotriton): + +| Sparsity | Time | Speedup vs SDPA | +|----------|------|-----------------| +| 100% | 33.0 ms | 0.18x | +| 50% | 13.7 ms | 0.43x | +| 25% | 7.4 ms | 0.79x | +| **10%** | **3.2 ms** | **1.81x** | +| 5% | 1.8 ms | 3.26x | +| 2% | 1.0 ms | 6.07x | + +**Break-even point**: ~20-25% sparsity. Below that, SpargeAttn is faster than dense SDPA. + +## Known Issues + +1. **No FP8 support on RDNA3** - rocWMMA on gfx11xx doesn't support FP8, so FP16/BF16 is used for V. + +2. **Triton compiler warnings** - You may see `clang-cl: warning: unknown argument ignored` warnings during first run. These are harmless. + +## Troubleshooting + +### "LoadLibrary failed" or "cannot find amdhip64.dll" + +Make sure you ran `rocm-sdk init` after installing the ROCm SDK packages. + +### "LINK : fatal error LNK1104: cannot open file 'python312.lib'" + +Ensure Visual Studio environment is activated before building: +```powershell +cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } } +``` + +### "PermissionError" when compiling Triton kernels + +This is a known Windows issue with temp file handling. Make sure you're using the latest `triton-windows` package (`pip install --upgrade triton-windows`). + diff --git a/csrc/fused/rocm/dispatch_utils.h b/csrc/fused/rocm/dispatch_utils.h new file mode 100644 index 0000000..14de852 --- /dev/null +++ b/csrc/fused/rocm/dispatch_utils.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + if (head_dim == 64) { \ + constexpr int HEAD_DIM = 64; \ + __VA_ARGS__ \ + } else if (head_dim == 128) { \ + constexpr int HEAD_DIM = 128; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head dim: " << int(head_dim); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \ + if (is_causal == 1) { \ + constexpr bool IS_CAUSAL = true; \ + __VA_ARGS__ \ + } else if (is_causal == 0) { \ + constexpr bool IS_CAUSAL = false; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported causal mode: " << int(is_causal); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...) \ + if (qk_quant_gran == 2) { \ + constexpr int QK_QUANT_GRAN = 2; \ + __VA_ARGS__ \ + } else if (qk_quant_gran == 3) { \ + constexpr int QK_QUANT_GRAN = 3; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported qk_quant_gran: " << int(qk_quant_gran); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...) \ + if (return_lse == 1) { \ + constexpr bool RETURN_LSE = true; \ + __VA_ARGS__ \ + } else if (return_lse == 0) { \ + constexpr bool RETURN_LSE = false; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported causal mode: " << int(return_lse); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + if (pytorch_dtype == at::ScalarType::Half) { \ + using c_type = half; \ + __VA_ARGS__ \ + } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ + using c_type = hip_bfloat16; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr int BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } else if (block_size == 128) { \ + constexpr int BLOCK_SIZE = 128; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported block_size " << int(block_size); \ + throw std::invalid_argument(err_msg.str()); \ + } + +#define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...) \ + if (warp_block_size == 16) { \ + constexpr int WARP_BLOCK_SIZE = 16; \ + __VA_ARGS__ \ + } else if (warp_block_size == 32) { \ + constexpr int WARP_BLOCK_SIZE = 32; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported warp_block_size " << int(warp_block_size); \ + throw std::invalid_argument(err_msg.str()); \ + } diff --git a/csrc/fused/rocm/fused.cu b/csrc/fused/rocm/fused.cu new file mode 100644 index 0000000..ac1dcb1 --- /dev/null +++ b/csrc/fused/rocm/fused.cu @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include +#include + +#include "dispatch_utils.h" +#include "../../utils.cuh" +#include "../../reduction_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +enum class QuantType +{ + kInt8, + kInt4, +}; + +__device__ __forceinline__ float u32_as_f32(uint32_t u) { + union { uint32_t u; float f; } v{u}; return v.f; +} + +__device__ __forceinline__ uint16_t bf16_bits(__hip_bfloat16 x) { + return *reinterpret_cast(&x); +} + +// ========== to-float ========== +template +__device__ __forceinline__ float convert_to_float(T val); + +// __half → float +template <> +__device__ __forceinline__ float convert_to_float<__half>(__half v) { + return __half2float(v); +} + +// __hip_bfloat16 → float (bit-level concatenation, high 16 bits) +template <> +__device__ __forceinline__ float convert_to_float<__hip_bfloat16>(__hip_bfloat16 v) { + uint16_t hi = bf16_bits(v); + return u32_as_f32(uint32_t(hi) << 16); +} + +// hip_bfloat16 → float (via reinterpret as __hip_bfloat16) +template <> +__device__ __forceinline__ float convert_to_float(hip_bfloat16 v) { + return convert_to_float(*reinterpret_cast(&v)); +} + +namespace detail { + + struct vec16_t { float x, y, z, w; }; + + template + __device__ __forceinline__ void predicated_g2s_16B(T* smem_dst, const T* gmem_src, bool pred) { + if (pred) { + *reinterpret_cast(smem_dst) = *reinterpret_cast(gmem_src); + } else if constexpr (PadZero) { + *reinterpret_cast(smem_dst) = vec16_t{0.f, 0.f, 0.f, 0.f}; + } + } + + __device__ __forceinline__ void store_8fp8(const uint32_t* __restrict__ fp8x4, + int8_t* __restrict__ out) { + *reinterpret_cast(out) = *reinterpret_cast(fp8x4); + } + + __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t* dest, float* s0, float* s1) { + + #ifdef __ROCM_ARCH_GFX942 + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + #else + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3); + #endif + + + *dest = (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) | ((uint32_t)b3 << 24); + } + +} // namespace detail + +template +__global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ output, float *__restrict__ mean, float *__restrict__ scale, const float scale_max, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_d_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + // static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + + uint32_t head_id = blockIdx.x; + uint32_t batch_id = blockIdx.y; + uint32_t d_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t num_threads = blockDim.x; + uint32_t gmem_stride = num_threads * pack_size; + // pad the number of tokens to 16 to deal with fp8 permute in previous kernel + uint32_t fp8_padded_num_tokens = (num_tokens + 15) / 16 * 16; + uint32_t num_iters = fp8_padded_num_tokens / gmem_stride + ((fp8_padded_num_tokens % gmem_stride) > thread_id * pack_size); + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + d_id * stride_d_input + thread_id * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + d_id * stride_d_output + thread_id * pack_size; + + T x_val[8]; + float x_val_float[8]; + uint32_t x_val_fp8[2]; + + float max_val = - 1000000.0f; + float min_val = 1000000.0f; + float sum_val = 0.0f; + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + float x_temp = convert_to_float(x_val[j]); + max_val = fmaxf(max_val, x_temp); + min_val = fminf(min_val, x_temp); + + if constexpr (sub_mean) + { + sum_val += x_temp; + } + } + } + + // reduce + __shared__ float s_amax_val; + __shared__ float s_mean_val; + + float block_max_val = vllm::blockReduceMax(max_val); + float block_min_val = vllm::blockReduceMin(min_val); + float block_sum_val; + + if constexpr (sub_mean) + { + block_sum_val = vllm::blockReduceSum(sum_val); + } + + if (thread_id == 0) + { + s_mean_val = block_sum_val / fp8_padded_num_tokens; + + if constexpr (sub_mean) + { + s_amax_val = fmaxf(fabsf(block_max_val - s_mean_val), fabsf(block_min_val - s_mean_val)); + mean[batch_id * stride_bz_mean + head_id * stride_h_mean + d_id] = s_mean_val; + } + else + { + s_amax_val = fmaxf(fabsf(block_max_val), fabsf(block_min_val)); + } + + scale[batch_id * stride_bz_scale + head_id * stride_h_scale + d_id] = s_amax_val / scale_max; + } + + __syncthreads(); + + float mean_val = s_mean_val; + float recp_scale = scale_max / s_amax_val; + + // recalculate num_iters to cover all fp8 output tokens to prevent nan in random initialization + uint32_t padded_num_tokens = (num_tokens + pad_size - 1) / pad_size * pad_size; + num_iters = padded_num_tokens / gmem_stride + ((padded_num_tokens % gmem_stride) > thread_id * pack_size); + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[j] = convert_to_float(x_val[j]); + if constexpr (sub_mean) + { + x_val_float[j] = (x_val_float[j] - mean_val) * recp_scale; + } + else + { + x_val_float[j] *= recp_scale; + } + } + + detail::floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + detail::floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + detail::store_8fp8(&x_val_fp8[0], output_ptr_base + i * gmem_stride); + } +} + + +template +__global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output) +{ + +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + uint32_t num_threads_per_token = head_dim / pack_size; + uint32_t num_threads_per_cta = CTA_SIZE / pack_size; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * CTA_SIZE + thread_id / num_threads_per_token; + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T* output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + + bx * CTA_SIZE + thread_id % num_threads_per_cta * pack_size + thread_id / num_threads_per_cta * stride_d_output; + + __shared__ T shared_load[CTA_SIZE][head_dim]; + __shared__ T shared_store[head_dim][CTA_SIZE]; + + uint32_t smem_load_row = thread_id / num_threads_per_token; + + detail::predicated_g2s_16B( + &shared_load[smem_load_row][ (thread_id % num_threads_per_token) * pack_size ], + input_ptr_base, + thread_base_token < num_tokens); + __syncthreads(); + + uint32_t smem_row_base = thread_id % CTA_SIZE; + uint32_t smem_col_base = thread_id / CTA_SIZE; + uint32_t smem_col_stride = head_dim / 8; + + // TODO: use ldmatrix to do permutation +#pragma unroll + for (uint32_t i = 0; i < 8; i++) + { + shared_store[smem_col_base + i * smem_col_stride][smem_row_base] = shared_load[smem_row_base][smem_col_base + i * smem_col_stride]; + } + + __syncthreads(); + + *reinterpret_cast(output_ptr_base) = + *reinterpret_cast( + &shared_store[ thread_id / num_threads_per_cta ] + [ (thread_id % num_threads_per_cta) * pack_size ]); +} + + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int num_tokens, + float scale_max, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int num_tokens_padded = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.size(2); + head_dim = input.size(1); + stride_d_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_heads = input.size(1); + head_dim = input.size(2); + stride_d_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + } + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + MeanScaleKernel<64, false, c_type><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + nullptr, + reinterpret_cast(scale.data_ptr()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + 0, 0, + scale.stride(0), scale.stride(1) + ); + }); +} + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + + constexpr int CTA_SIZE = 64; + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, padded_num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + + CHECK_SHAPE(output, batch_size, head_dim, num_heads, padded_num_tokens); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + CHECK_SHAPE(output, batch_size, num_heads, head_dim, padded_num_tokens); + } + + auto input_dtype = input.scalar_type(); + auto output_dtype = output.scalar_type(); + + TORCH_CHECK(input_dtype == output_dtype, "Input and output must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 grid(padded_num_tokens / CTA_SIZE, num_heads, batch_size); + + static_assert(CTA_SIZE * HEAD_DIM <= 8192); + + dim3 block(CTA_SIZE * (HEAD_DIM / 8)); + + TransposePadPermuteKernel<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output + ); + }); + }); +} diff --git a/csrc/fused/rocm/fused.h b/csrc/fused/rocm/fused.h new file mode 100644 index 0000000..e09fcae --- /dev/null +++ b/csrc/fused/rocm/fused.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int tensor_layout); + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int num_tokens, + float scale_max, + int tensor_layout); \ No newline at end of file diff --git a/csrc/fused/rocm/pybind_rocm.cpp b/csrc/fused/rocm/pybind_rocm.cpp new file mode 100644 index 0000000..d29b8b7 --- /dev/null +++ b/csrc/fused/rocm/pybind_rocm.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "fused.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ +// m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); +// m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); +// m.def("quant_per_block_int8_fuse_sub_mean_cuda", py::overload_cast(&quant_per_block_int8_fuse_sub_mean_cuda), "quant_per_block_int8_fuse_sub_mean_cuda"); +// m.def("quant_per_warp_int8_cuda", py::overload_cast(&quant_per_warp_int8_cuda), "quant_per_warp_int8_cuda"); + +// m.def("sub_mean_cuda", py::overload_cast(&sub_mean_cuda), "sub_mean_cuda"); + + m.def("transpose_pad_permute_cuda", py::overload_cast(&transpose_pad_permute_cuda), "transpose_pad_permute_cuda"); + m.def("scale_fuse_quant_cuda", py::overload_cast(&scale_fuse_quant_cuda), "scale_fuse_quant_cuda"); +// m.def("mean_scale_fuse_quant_cuda", py::overload_cast(&mean_scale_fuse_quant_cuda), "mean_scale_fuse_quant_cuda"); +} \ No newline at end of file diff --git a/csrc/qattn/rocm/attn_rocm.h b/csrc/qattn/rocm/attn_rocm.h new file mode 100644 index 0000000..cc8b7e5 --- /dev/null +++ b/csrc/qattn/rocm/attn_rocm.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2025 by SpargeAttn team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// FP16 V matrix - block sparse attention for all ROCm architectures +void qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor lut, + torch::Tensor valid_block_num, + torch::Tensor query_scale, + torch::Tensor key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale); + +// FP16 V matrix with PV threshold - block sparse attention for all ROCm architectures +torch::Tensor qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor lut, + torch::Tensor valid_block_num, + torch::Tensor pv_threshold, + torch::Tensor query_scale, + torch::Tensor key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_pv_count); + +#if defined(SA_ARCH_MI_SERIES) +// FP8 V matrix - block sparse attention for MI-series GPUs +void qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor lut, + torch::Tensor valid_block_num, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale); + +// FP8 V matrix with PV threshold - block sparse attention for MI-series GPUs +torch::Tensor qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor lut, + torch::Tensor valid_block_num, + torch::Tensor pv_threshold, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_pv_count); +#endif diff --git a/csrc/qattn/rocm/dispatch_utils.h b/csrc/qattn/rocm/dispatch_utils.h new file mode 100644 index 0000000..bcd9ae2 --- /dev/null +++ b/csrc/qattn/rocm/dispatch_utils.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2025 by SpargeAttn team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + [&]() { \ + if (pytorch_dtype == at::ScalarType::Half) { \ + using c_type = __half; \ + return __VA_ARGS__(); \ + } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ + using c_type = __hip_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype: ", pytorch_dtype); \ + } \ + }() + +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + [&]() { \ + if (head_dim == 64) { \ + constexpr int HEAD_DIM = 64; \ + return __VA_ARGS__(); \ + } else if (head_dim == 128) { \ + constexpr int HEAD_DIM = 128; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported head_dim: ", head_dim); \ + } \ + }() + +#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \ + [&]() { \ + if (is_causal) { \ + constexpr bool IS_CAUSAL = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool IS_CAUSAL = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...) \ + [&]() { \ + if (qk_quant_gran == 1) { \ + constexpr int QK_QUANT_GRAN = 1; \ + return __VA_ARGS__(); \ + } else if (qk_quant_gran == 2) { \ + constexpr int QK_QUANT_GRAN = 2; \ + return __VA_ARGS__(); \ + } else if (qk_quant_gran == 3) { \ + constexpr int QK_QUANT_GRAN = 3; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported qk_quant_gran: ", qk_quant_gran); \ + } \ + }() diff --git a/csrc/qattn/rocm/launch_sgattn_f16.cu b/csrc/qattn/rocm/launch_sgattn_f16.cu new file mode 100644 index 0000000..1c36f86 --- /dev/null +++ b/csrc/qattn/rocm/launch_sgattn_f16.cu @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2025 by SpargeAttn team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "dispatch_utils.h" +#include "attn_rocm.h" + +// Forward declaration of kernel launcher - DTypeV is the value/output type (half or hip_bfloat16) +template +void SpargeAttentionROCmF16Dispatched( + int8_t* Q, int8_t* K, DTypeV* V, DTypeV* O, + int32_t* PV_Count, int32_t* Lut, int32_t* Valid_Block_Num, float* PV_Threshold, + float* Q_scale, float* K_scale, + const uint32_t batch_size, const uint32_t qo_len, const uint32_t kv_len, + const uint32_t num_qo_heads, const uint32_t num_kv_heads, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale); + +inline uint32_t div_ceil(uint32_t a, uint32_t b) { return (a + b - 1) / b; } + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LASTDIM_CONTIGUOUS(x) TORCH_CHECK(x.stride(-1) == 1, #x " must have contiguous last dim") +#define CHECK_DTYPE(x, dtype) TORCH_CHECK(x.scalar_type() == dtype, #x " must have dtype " #dtype) +#define CHECK_DIMS(x, dims) TORCH_CHECK(x.dim() == dims, #x " must have " #dims " dimensions") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == c10::IntArrayRef({__VA_ARGS__}), #x " has wrong shape") + +void qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor lut, + torch::Tensor valid_block_num, + torch::Tensor query_scale, + torch::Tensor key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(lut); + CHECK_CUDA(valid_block_num); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_LASTDIM_CONTIGUOUS(value); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(lut); + CHECK_CONTIGUOUS(valid_block_num); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + TORCH_CHECK(value.scalar_type() == torch::kFloat16 || value.scalar_type() == torch::kBFloat16, + "value must be Float16 or BFloat16"); + TORCH_CHECK(output.scalar_type() == value.scalar_type(), + "output must have same dtype as value"); + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + + const int batch_size = query.size(0); + const int head_dim = query.size(3); + + int stride_bz_q = query.stride(0); + int stride_bz_k = key.stride(0); + int stride_bz_v = value.stride(0); + int stride_bz_o = output.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; + int stride_h_q, stride_h_k, stride_h_v, stride_h_o; + + // tensor_layout: 0 = [batch, seq, head, dim], 1 = [batch, head, seq, dim] + if (tensor_layout == 0) { + qo_len = query.size(1); + kv_len = key.size(1); + num_qo_heads = query.size(2); + num_kv_heads = key.size(2); + stride_seq_q = query.stride(1); + stride_seq_k = key.stride(1); + stride_seq_v = value.stride(1); + stride_seq_o = output.stride(1); + stride_h_q = query.stride(2); + stride_h_k = key.stride(2); + stride_h_v = value.stride(2); + stride_h_o = output.stride(2); + } else { + qo_len = query.size(2); + kv_len = key.size(2); + num_qo_heads = query.size(1); + num_kv_heads = key.size(1); + stride_seq_q = query.stride(2); + stride_seq_k = key.stride(2); + stride_seq_v = value.stride(2); + stride_seq_o = output.stride(2); + stride_h_q = query.stride(1); + stride_h_k = key.stride(1); + stride_h_v = value.stride(1); + stride_h_o = output.stride(1); + } + + TORCH_CHECK(head_dim == 64 || head_dim == 128, "head_dim must be 64 or 128, got ", head_dim); + TORCH_CHECK(num_qo_heads % num_kv_heads == 0, "num_qo_heads must be divisible by num_kv_heads"); + TORCH_CHECK(qk_quant_gran >= 1 && qk_quant_gran <= 2, "qk_quant_gran must be 1 (per-block) or 2 (per-warp)"); + + const bool is_bf16 = value.scalar_type() == torch::kBFloat16; + + // Dispatch based on head_dim, causal, quantization granularity, and dtype + // For head_dim=64: CTA_Q=64, CTA_K=64 + // For head_dim=128: CTA_Q=32, CTA_K=32 (smaller tiles reduce register pressure) + // Note: CTA_K must match BLKK used in Python quantization code + + #define DISPATCH_KERNEL_64(QUANT_GRAN, CAUSAL, DTYPE) \ + SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, QUANT_GRAN, true, 0, DTYPE, CAUSAL, false>( \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + nullptr, \ + reinterpret_cast(lut.data_ptr()), \ + reinterpret_cast(valid_block_num.data_ptr()), \ + nullptr, \ + reinterpret_cast(query_scale.data_ptr()), \ + reinterpret_cast(key_scale.data_ptr()), \ + batch_size, qo_len, kv_len, num_qo_heads, num_kv_heads, \ + stride_bz_q, stride_seq_q, stride_h_q, \ + stride_bz_k, stride_seq_k, stride_h_k, \ + stride_bz_v, stride_seq_v, stride_h_v, \ + stride_bz_o, stride_seq_o, stride_h_o, \ + sm_scale) + + // CTA_K=32 for head_dim=128 (default) + #define DISPATCH_KERNEL_128_K32(QUANT_GRAN, CAUSAL, DTYPE) \ + SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, QUANT_GRAN, true, 0, DTYPE, CAUSAL, false>( \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + nullptr, \ + reinterpret_cast(lut.data_ptr()), \ + reinterpret_cast(valid_block_num.data_ptr()), \ + nullptr, \ + reinterpret_cast(query_scale.data_ptr()), \ + reinterpret_cast(key_scale.data_ptr()), \ + batch_size, qo_len, kv_len, num_qo_heads, num_kv_heads, \ + stride_bz_q, stride_seq_q, stride_h_q, \ + stride_bz_k, stride_seq_k, stride_h_k, \ + stride_bz_v, stride_seq_v, stride_h_v, \ + stride_bz_o, stride_seq_o, stride_h_o, \ + sm_scale) + + // CTA_K=16 for head_dim=128 (experimental - finer granularity) + #define DISPATCH_KERNEL_128_K16(QUANT_GRAN, CAUSAL, DTYPE) \ + SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, QUANT_GRAN, true, 0, DTYPE, CAUSAL, false>( \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + nullptr, \ + reinterpret_cast(lut.data_ptr()), \ + reinterpret_cast(valid_block_num.data_ptr()), \ + nullptr, \ + reinterpret_cast(query_scale.data_ptr()), \ + reinterpret_cast(key_scale.data_ptr()), \ + batch_size, qo_len, kv_len, num_qo_heads, num_kv_heads, \ + stride_bz_q, stride_seq_q, stride_h_q, \ + stride_bz_k, stride_seq_k, stride_h_k, \ + stride_bz_v, stride_seq_v, stride_h_v, \ + stride_bz_o, stride_seq_o, stride_h_o, \ + sm_scale) + + // Use CTA_K=16 for testing (finer granularity, potentially lower overhead) + #define DISPATCH_KERNEL_128 DISPATCH_KERNEL_128_K16 + + if (head_dim == 64) { + if (is_bf16) { + if (qk_quant_gran == 1) { + if (is_causal) { + DISPATCH_KERNEL_64(1, true, hip_bfloat16); + } else { + DISPATCH_KERNEL_64(1, false, hip_bfloat16); + } + } else { + if (is_causal) { + DISPATCH_KERNEL_64(2, true, hip_bfloat16); + } else { + DISPATCH_KERNEL_64(2, false, hip_bfloat16); + } + } + } else { // fp16 + if (qk_quant_gran == 1) { + if (is_causal) { + DISPATCH_KERNEL_64(1, true, half); + } else { + DISPATCH_KERNEL_64(1, false, half); + } + } else { + if (is_causal) { + DISPATCH_KERNEL_64(2, true, half); + } else { + DISPATCH_KERNEL_64(2, false, half); + } + } + } + } else { // head_dim == 128 + if (is_bf16) { + if (qk_quant_gran == 1) { + if (is_causal) { + DISPATCH_KERNEL_128(1, true, hip_bfloat16); + } else { + DISPATCH_KERNEL_128(1, false, hip_bfloat16); + } + } else { + if (is_causal) { + DISPATCH_KERNEL_128(2, true, hip_bfloat16); + } else { + DISPATCH_KERNEL_128(2, false, hip_bfloat16); + } + } + } else { // fp16 + if (qk_quant_gran == 1) { + if (is_causal) { + DISPATCH_KERNEL_128(1, true, half); + } else { + DISPATCH_KERNEL_128(1, false, half); + } + } else { + if (is_causal) { + DISPATCH_KERNEL_128(2, true, half); + } else { + DISPATCH_KERNEL_128(2, false, half); + } + } + } + } + + #undef DISPATCH_KERNEL_64 + #undef DISPATCH_KERNEL_128 + #undef DISPATCH_KERNEL_128_K32 + #undef DISPATCH_KERNEL_128_K16 +} + +torch::Tensor qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor lut, + torch::Tensor valid_block_num, + torch::Tensor pv_threshold, + torch::Tensor query_scale, + torch::Tensor key_scale, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_pv_count) +{ + // Call the basic version - PV threshold support can be added later + qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf( + query, key, value, output, lut, valid_block_num, + query_scale, key_scale, tensor_layout, is_causal, qk_quant_gran, sm_scale); + + // Return empty tensor for pv_count (not implemented yet) + return torch::empty({0}, query.options().dtype(torch::kInt32)); +} diff --git a/csrc/qattn/rocm/pybind_rocm.cpp b/csrc/qattn/rocm/pybind_rocm.cpp new file mode 100644 index 0000000..089ec06 --- /dev/null +++ b/csrc/qattn/rocm/pybind_rocm.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2025 by SpargeAttn team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "attn_rocm.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + // FP16 V matrix - works on all ROCm architectures + m.def("qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf", + &qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf, + "QK int8 SV f16 block sparse attention (all ROCm GPUs)"); + + m.def("qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold", + &qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold, + "QK int8 SV f16 block sparse attention with PV threshold (all ROCm GPUs)"); + +#if defined(SA_ARCH_MI_SERIES) + // FP8 V matrix - MI series only + m.def("qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale", + &qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale, + "QK int8 SV f8 block sparse attention (MI series GPUs)"); + + m.def("qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold", + &qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold, + "QK int8 SV f8 block sparse attention with PV threshold (MI series GPUs)"); +#endif +} diff --git a/csrc/qattn/rocm/sgattn_f16.cu b/csrc/qattn/rocm/sgattn_f16.cu new file mode 100644 index 0000000..729751d --- /dev/null +++ b/csrc/qattn/rocm/sgattn_f16.cu @@ -0,0 +1,739 @@ +/* + * Copyright (c) 2025 by SpargeAttn team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Register-based attention kernel for AMD RDNA3 (gfx11) using rocWMMA. + * + * RDNA3 WMMA D[row][col] layout (from AMD Matrix Instruction Calculator): + * fragment.x[reg] where reg ∈ [0,7]: + * row = reg * 2 + (lane_id >= 16 ? 1 : 0) + * col = lane_id % 16 + * + * Each lane handles ONE column across 8 rows (either even or odd rows). + * Lanes 0-15 handle even rows (0,2,4,...,14) + * Lanes 16-31 handle odd rows (1,3,5,...,15) + */ + +#include +#include +#include +#include + +using namespace rocwmma; + +namespace gfx11Params { + constexpr uint32_t WAVE_SIZE = 32u; + constexpr uint32_t WMMA_M = 16u; + constexpr uint32_t WMMA_N = 16u; + constexpr uint32_t WMMA_K_INT8 = 16u; // RDNA3 WMMA K for INT8 + constexpr uint32_t WMMA_K_FP16 = 16u; // RDNA3 WMMA K for FP16/BF16 +} + +constexpr float LOG2E = 1.44269504088896340736f; +#define div_ceil_hip(M, N) (((M) + (N) - 1) / (N)) + +enum class QuantGranularity { + kPerTensor = 0, + kPerBlock = 1, + kPerWarp = 2, + kPerThread = 3, +}; + +enum class MaskMode { + kNone = 0, + kCausal = 1, +}; + +// Type traits for FP16/BF16 support +template struct TypeTraits; + +template<> +struct TypeTraits { + __device__ static float to_float(half val) { return __half2float(val); } + __device__ static half from_float(float val) { return __float2half(val); } +}; + +template<> +struct TypeTraits { + __device__ static float to_float(hip_bfloat16 val) { return static_cast(val); } + __device__ static hip_bfloat16 from_float(float val) { return hip_bfloat16(val); } +}; + +// Fragment types for QK phase: INT8 M16N16K16 +using FragA_QK = fragment; +using FragB_QK = fragment; +using FragAcc_QK = fragment; + +// Fragment types for SV phase: FP16 M16N16K16 +template struct SVFragmentTypes; + +template<> +struct SVFragmentTypes { + using FragA = fragment; + using FragB = fragment; + using FragAcc = fragment; +}; + +template<> +struct SVFragmentTypes { + using FragA = fragment; + using FragB = fragment; + using FragAcc = fragment; +}; + +/* + * RDNA3 WMMA element-to-matrix mapping helpers. + * Based on AMD Matrix Instruction Calculator output. + */ +__device__ __forceinline__ uint32_t wmma_elem_row(uint32_t reg, uint32_t lane_id) { + // reg ∈ [0,7], row = reg * 2 + (lane_id >= 16 ? 1 : 0) + return reg * 2 + (lane_id >> 4); +} + +__device__ __forceinline__ uint32_t wmma_elem_col(uint32_t lane_id) { + // col = lane_id % 16 + return lane_id & 15; +} + +/* + * Register-based attention kernel. + * + * Key insight for RDNA3 softmax: + * - Each lane owns one column across 8 rows + * - For row-wise softmax, we need to reduce across 16 lanes (0-15 for even rows, 16-31 for odd) + * - Use __shfl_xor within the lane group for fast reduction + */ +template +__global__ void __launch_bounds__(NUM_THREADS) +qk_int_sv_f16_block_sparse_attn_kernel_rocm( + int8_t* __restrict__ Q, + int8_t* __restrict__ K, + DTypeV* __restrict__ V, + DTypeV* __restrict__ O, + int32_t* __restrict__ PV_Count, + int32_t* __restrict__ Lut, + int32_t* __restrict__ Valid_Block_Num, + float* __restrict__ PV_Threshold, + float* __restrict__ Q_scale, + float* __restrict__ K_scale, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + using namespace gfx11Params; + using Traits = TypeTraits; + using SVFrags = SVFragmentTypes; + using FragA_SV_T = typename SVFrags::FragA; + using FragB_SV_T = typename SVFrags::FragB; + using FragAcc_SV_T = typename SVFrags::FragAcc; + + constexpr uint32_t NUM_WARPS_Q = CTA_Q / WARP_Q; + constexpr uint32_t NUM_WARPS_K = 1; // Always 1 for our configuration + constexpr uint32_t NUM_TILES_K = CTA_K / WMMA_N; // 4 for CTA_K=64 + constexpr uint32_t NUM_TILES_V = HEAD_DIM / WMMA_N; // 8 for HD=128, 4 for HD=64 + constexpr uint32_t NUM_K_ITERS = HEAD_DIM / WMMA_K_INT8; // 8 for HD=128 (K=16) + constexpr uint32_t NUM_SV_ITERS = CTA_K / WMMA_K_FP16; // 4 for CTA_K=64, K=16 + + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t head_id = blockIdx.y; + const uint32_t num_qo_heads = gridDim.y; + + const uint32_t tid = threadIdx.x + threadIdx.y * blockDim.x; + const uint32_t warp_id = tid / WAVE_SIZE; + const uint32_t lane_id = tid % WAVE_SIZE; + const uint32_t warp_idx_q = warp_id / NUM_WARPS_K; + const uint32_t warp_idx_k = warp_id % NUM_WARPS_K; + + // For RDNA3 WMMA layout: lane 0-15 handle even rows, lane 16-31 handle odd rows + const uint32_t is_odd_row_lane = lane_id >> 4; // 0 for lanes 0-15, 1 for lanes 16-31 + const uint32_t lane_col = lane_id & 15; // Column index within 16x16 tile + + sm_scale *= LOG2E; + + const uint32_t num_block_q = gridDim.x; + const uint32_t num_block_k = div_ceil_hip(kv_len, CTA_K); + const uint32_t num_iterations = Valid_Block_Num[batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx]; + + if (num_iterations == 0) return; + + const int32_t* lut_ptr = Lut + batch_id * num_qo_heads * num_block_q * num_block_k + + head_id * num_block_q * num_block_k + bx * num_block_k; + + // Shared memory layout + extern __shared__ char smem[]; + + int8_t* smem_Q = reinterpret_cast(smem); + int8_t* smem_K = smem_Q + HEAD_DIM * CTA_Q; + DTypeV* smem_V = reinterpret_cast(smem_K + HEAD_DIM * CTA_K); + DTypeV* smem_S = smem_V + CTA_K * HEAD_DIM; // For storing S after softmax + + const uint32_t q_start = bx * CTA_Q; + const uint32_t q_tile_row = warp_idx_q * WMMA_M; // Each warp handles 16 Q rows + + // ======================================== + // REGISTER-BASED STATE + // ======================================== + + // Output accumulator: RO[NUM_TILES_V][8] - 8 elements per 16x16 tile + float RO[NUM_TILES_V][8]; + + // Per-row softmax state: m (max) and d (sum) + // Each lane owns 8 elements in the same column but different rows + // We need per-row max/sum, so each element has its own m and d + float m[8]; // max for each of the 8 rows this lane owns + float d[8]; // sum for each of the 8 rows this lane owns + + // Initialize RO, m, d + #pragma unroll + for (uint32_t fv = 0; fv < NUM_TILES_V; fv++) { + #pragma unroll + for (uint32_t i = 0; i < 8; i++) { + RO[fv][i] = 0.0f; + } + } + #pragma unroll + for (uint32_t i = 0; i < 8; i++) { + m[i] = -5000000.0f; + d[i] = 0.0f; + } + + // Load Q to shared memory (col-major for WMMA A) + for (uint32_t i = tid; i < CTA_Q * HEAD_DIM; i += NUM_THREADS) { + uint32_t q_row = i % CTA_Q; + uint32_t q_col = i / CTA_Q; + uint32_t q_idx = q_start + q_row; + int8_t val = 0; + if (q_idx < qo_len) { + val = Q[batch_id * stride_bz_q + q_idx * stride_seq_q + head_id * stride_h_q + q_col]; + } + smem_Q[i] = val; + } + __syncthreads(); + + // Get Q scale + float q_scale_val; + if constexpr (Q_GRAN == QuantGranularity::kPerBlock) { + q_scale_val = Q_scale[batch_id * num_qo_heads * num_block_q + head_id * num_block_q + bx]; + } else if constexpr (Q_GRAN == QuantGranularity::kPerWarp) { + const uint32_t num_warp_block_q = num_block_q * NUM_WARPS_Q; + q_scale_val = Q_scale[batch_id * num_qo_heads * num_warp_block_q + head_id * num_warp_block_q + bx * NUM_WARPS_Q + warp_idx_q]; + } + + // Main loop over K blocks + uint32_t k_block_idx = 0; + for (uint32_t iter = 0; iter < num_iterations; iter++) { + k_block_idx += lut_ptr[iter]; + uint32_t k_start = k_block_idx * CTA_K; + + // Load K to shared memory (as K^T) + for (uint32_t i = tid; i < CTA_K * HEAD_DIM; i += NUM_THREADS) { + uint32_t n = i % CTA_K; + uint32_t k = i / CTA_K; + uint32_t k_idx = k_start + n; + int8_t val = 0; + if (k_idx < kv_len) { + val = K[batch_id * stride_bz_k + k_idx * stride_seq_k + (head_id / num_kv_groups) * stride_h_k + k]; + } + smem_K[k * CTA_K + n] = val; + } + + // Load V to shared memory + for (uint32_t i = tid; i < CTA_K * HEAD_DIM; i += NUM_THREADS) { + uint32_t v_row = i / HEAD_DIM; + uint32_t v_col = i % HEAD_DIM; + uint32_t v_idx = k_start + v_row; + DTypeV val = Traits::from_float(0.0f); + if (v_idx < kv_len) { + val = V[batch_id * stride_bz_v + v_idx * stride_seq_v + (head_id / num_kv_groups) * stride_h_v + v_col]; + } + smem_V[v_row * HEAD_DIM + v_col] = val; + } + __syncthreads(); + + // Get K scale + float k_scale_val; + if constexpr (K_GRAN == QuantGranularity::kPerBlock) { + const uint32_t num_kv_heads = num_qo_heads / num_kv_groups; + k_scale_val = K_scale[batch_id * num_kv_heads * num_block_k + (head_id / num_kv_groups) * num_block_k + k_block_idx]; + } else if constexpr (K_GRAN == QuantGranularity::kPerWarp) { + const uint32_t num_kv_heads = num_qo_heads / num_kv_groups; + const uint32_t num_warp_block_k = num_block_k * NUM_WARPS_K; + k_scale_val = K_scale[batch_id * num_kv_heads * num_warp_block_k + (head_id / num_kv_groups) * num_warp_block_k + k_block_idx * NUM_WARPS_K + warp_idx_k]; + } + + float dequant_scale = q_scale_val * k_scale_val * sm_scale; + + // ============================================ + // Phase 1: Compute QK^T using rocWMMA INT8 + // RS[NUM_TILES_K][8] holds the QK results in registers + // ============================================ + + float RS[NUM_TILES_K][8]; // QK results for all K tiles + + #pragma unroll + for (uint32_t tile_k = 0; tile_k < NUM_TILES_K; tile_k++) { + FragAcc_QK acc_qk; + fill_fragment(acc_qk, 0); + + #pragma unroll + for (uint32_t k_iter = 0; k_iter < NUM_K_ITERS; k_iter++) { + FragA_QK frag_q; + load_matrix_sync(frag_q, smem_Q + k_iter * WMMA_K_INT8 * CTA_Q + q_tile_row, CTA_Q); + + FragB_QK frag_k; + load_matrix_sync(frag_k, smem_K + k_iter * WMMA_K_INT8 * CTA_K + tile_k * WMMA_N, CTA_K); + mma_sync(acc_qk, frag_q, frag_k, acc_qk); + } + + // Dequantize and apply masks directly in registers + #pragma unroll + for (uint32_t reg = 0; reg < 8; reg++) { + float val = static_cast(acc_qk.x[reg]) * dequant_scale; + + // Apply out-of-bounds and causal masks + uint32_t row = wmma_elem_row(reg, lane_id); + uint32_t col = wmma_elem_col(lane_id); + uint32_t q_idx = q_start + q_tile_row + row; + uint32_t k_idx = k_start + tile_k * WMMA_N + col; + + if (k_idx >= kv_len) val = -5000000.0f; + if constexpr (mask_mode == MaskMode::kCausal) { + if (k_idx > q_idx) val = -5000000.0f; + } + + RS[tile_k][reg] = val; + } + } + + // ============================================ + // Phase 2: Online softmax update (register-based) + // For each of the 8 rows this lane owns, find max across K tiles + // Then reduce across 16 lanes (shuffle within lane group 0-15 or 16-31) + // ============================================ + + #pragma unroll + for (uint32_t reg = 0; reg < 8; reg++) { + float m_prev = m[reg]; + + // Find max across all K tiles for this row element + float m_local = RS[0][reg]; + #pragma unroll + for (uint32_t tile_k = 1; tile_k < NUM_TILES_K; tile_k++) { + m_local = fmaxf(m_local, RS[tile_k][reg]); + } + + // Warp-level max reduction across 16 lanes (either 0-15 or 16-31) + // Only reduce within the same row-parity group + #pragma unroll + for (uint32_t offset = 8; offset > 0; offset /= 2) { + m_local = fmaxf(m_local, __shfl_xor(m_local, offset, WAVE_SIZE)); + } + + m[reg] = fmaxf(m_prev, m_local); + float o_scale = exp2f(m_prev - m[reg]); + + // Scale existing d and RO + d[reg] *= o_scale; + #pragma unroll + for (uint32_t fv = 0; fv < NUM_TILES_V; fv++) { + RO[fv][reg] *= o_scale; + } + + // Compute exp and accumulate to d + float local_sum = 0.0f; + #pragma unroll + for (uint32_t tile_k = 0; tile_k < NUM_TILES_K; tile_k++) { + RS[tile_k][reg] = exp2f(RS[tile_k][reg] - m[reg]); + local_sum += RS[tile_k][reg]; + } + + // Warp-level sum reduction across 16 lanes + #pragma unroll + for (uint32_t offset = 8; offset > 0; offset /= 2) { + local_sum += __shfl_xor(local_sum, offset, WAVE_SIZE); + } + + d[reg] += local_sum; + } + + // ============================================ + // Phase 3: Store S to shared memory for S@V + // Convert RS to DTypeV and store in row-major layout + // ============================================ + + #pragma unroll + for (uint32_t tile_k = 0; tile_k < NUM_TILES_K; tile_k++) { + #pragma unroll + for (uint32_t reg = 0; reg < 8; reg++) { + uint32_t row = wmma_elem_row(reg, lane_id); + uint32_t col = wmma_elem_col(lane_id); + uint32_t global_row = q_tile_row + row; + uint32_t global_col = tile_k * WMMA_N + col; + + smem_S[global_row * CTA_K + global_col] = Traits::from_float(RS[tile_k][reg]); + } + } + __syncthreads(); + + // ============================================ + // Phase 4: Compute S @ V using rocWMMA FP16 + // Accumulate into RO registers + // ============================================ + + #pragma unroll + for (uint32_t tile_v = 0; tile_v < NUM_TILES_V; tile_v++) { + FragAcc_SV_T acc_sv; + fill_fragment(acc_sv, 0.0f); + + #pragma unroll + for (uint32_t k_iter = 0; k_iter < NUM_SV_ITERS; k_iter++) { + FragA_SV_T frag_s; + // S is stored row-major: [CTA_Q, CTA_K] + load_matrix_sync(frag_s, smem_S + q_tile_row * CTA_K + k_iter * WMMA_K_FP16, CTA_K); + + FragB_SV_T frag_v; + load_matrix_sync(frag_v, smem_V + k_iter * WMMA_K_FP16 * HEAD_DIM + tile_v * WMMA_N, HEAD_DIM); + + mma_sync(acc_sv, frag_s, frag_v, acc_sv); + } + + // Accumulate to RO using direct element access + #pragma unroll + for (uint32_t reg = 0; reg < 8; reg++) { + RO[tile_v][reg] += acc_sv.x[reg]; + } + } + + __syncthreads(); + } + + // ============================================ + // Final: Normalize by d and write to output + // ============================================ + + // Write RO to output via shared memory + // Reuse smem_S as temp buffer + DTypeV* smem_out = smem_S; + + #pragma unroll + for (uint32_t tile_v = 0; tile_v < NUM_TILES_V; tile_v++) { + #pragma unroll + for (uint32_t reg = 0; reg < 8; reg++) { + uint32_t row = wmma_elem_row(reg, lane_id); + uint32_t col = wmma_elem_col(lane_id); + uint32_t global_row = q_tile_row + row; + uint32_t global_col = tile_v * WMMA_N + col; + + float inv_d = 1.0f / d[reg]; + float out_val = RO[tile_v][reg] * inv_d; + + smem_out[global_row * HEAD_DIM + global_col] = Traits::from_float(out_val); + } + } + __syncthreads(); + + // Copy from smem to global memory + for (uint32_t i = tid; i < CTA_Q * HEAD_DIM; i += NUM_THREADS) { + uint32_t row = i / HEAD_DIM; + uint32_t col = i % HEAD_DIM; + uint32_t o_idx = q_start + row; + + if (o_idx < qo_len) { + O[batch_id * stride_bz_o + o_idx * stride_seq_o + head_id * stride_h_o + col] = smem_out[row * HEAD_DIM + col]; + } + } +} + +// Kernel launcher +template +void SpargeAttentionROCmF16Dispatched( + int8_t* Q, int8_t* K, DTypeV* V, DTypeV* O, + int32_t* PV_Count, int32_t* Lut, int32_t* Valid_Block_Num, float* PV_Threshold, + float* Q_scale, float* K_scale, + const uint32_t batch_size, const uint32_t qo_len, const uint32_t kv_len, + const uint32_t num_qo_heads, const uint32_t num_kv_heads, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_seq_v, const uint32_t stride_h_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + constexpr QuantGranularity Q_GRAN = (qk_quant_gran == 1) ? QuantGranularity::kPerBlock : QuantGranularity::kPerWarp; + constexpr QuantGranularity K_GRAN = Q_GRAN; + constexpr MaskMode mask_mode = is_causal ? MaskMode::kCausal : MaskMode::kNone; + + const uint32_t num_kv_groups = num_qo_heads / num_kv_heads; + const uint32_t num_block_q = div_ceil_hip(qo_len, CTA_Q); + + // Calculate shared memory size (reduced - no smem_O, no smem_QK_float!) + // smem_Q: HEAD_DIM * CTA_Q (int8) + // smem_K: HEAD_DIM * CTA_K (int8) + // smem_V: CTA_K * HEAD_DIM (DTypeV) + // smem_S: CTA_Q * CTA_K (DTypeV) - also used as output buffer + + size_t smem_size = HEAD_DIM * CTA_Q * sizeof(int8_t) + // smem_Q + HEAD_DIM * CTA_K * sizeof(int8_t) + // smem_K + CTA_K * HEAD_DIM * sizeof(DTypeV) + // smem_V + max(CTA_Q * CTA_K, CTA_Q * HEAD_DIM) * sizeof(DTypeV); // smem_S / smem_out + + constexpr uint32_t NUM_WARPS = CTA_Q / WARP_Q; + constexpr uint32_t NUM_THREADS = NUM_WARPS * gfx11Params::WAVE_SIZE; + + dim3 grid(num_block_q, num_qo_heads, batch_size); + dim3 block(NUM_THREADS, 1, 1); + + hipLaunchKernelGGL((qk_int_sv_f16_block_sparse_attn_kernel_rocm), + grid, block, smem_size, 0, + Q, K, V, O, PV_Count, Lut, Valid_Block_Num, PV_Threshold, Q_scale, K_scale, + qo_len, kv_len, num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_seq_v, stride_h_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); +} + +// Explicit template instantiations +// qk_quant_gran=1 (kPerBlock) +// CTA_Q=64, HEAD_DIM=64, half +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 1, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 1, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// CTA_Q=64, HEAD_DIM=64, bfloat16 +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 1, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 1, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// CTA_Q=32, HEAD_DIM=128, half +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 1, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 1, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// CTA_Q=32, HEAD_DIM=128, bfloat16 +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 1, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 1, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// qk_quant_gran=2 (kPerWarp) +// CTA_Q=64, HEAD_DIM=64, half +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 2, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 2, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// CTA_Q=64, HEAD_DIM=64, bfloat16 +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 2, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<64, 64, 16, 64, 64, 2, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// CTA_Q=32, HEAD_DIM=128, half +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 2, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 2, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// CTA_Q=32, HEAD_DIM=128, bfloat16 +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 2, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 64, 16, 64, 128, 2, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// ======================================== +// CTA_K=32 instantiations for reduced register pressure (HEAD_DIM=128 only) +// ======================================== + +// qk_quant_gran=1 (kPerBlock), CTA_Q=32, CTA_K=32, HEAD_DIM=128, half +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 1, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 1, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// qk_quant_gran=1 (kPerBlock), CTA_Q=32, CTA_K=32, HEAD_DIM=128, bfloat16 +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 1, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 1, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// qk_quant_gran=2 (kPerWarp), CTA_Q=32, CTA_K=32, HEAD_DIM=128, half +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 2, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 2, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// qk_quant_gran=2 (kPerWarp), CTA_Q=32, CTA_K=32, HEAD_DIM=128, bfloat16 +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 2, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 32, 16, 32, 128, 2, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// ======================================== +// CTA_K=16 instantiations (even finer granularity) +// ======================================== + +// qk_quant_gran=1 (kPerBlock), CTA_Q=32, CTA_K=16, HEAD_DIM=128, half +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 1, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 1, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// qk_quant_gran=1 (kPerBlock), CTA_Q=32, CTA_K=16, HEAD_DIM=128, bfloat16 +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 1, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 1, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// qk_quant_gran=2 (kPerWarp), CTA_Q=32, CTA_K=16, HEAD_DIM=128, half +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 2, true, 0, half, true, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 2, true, 0, half, false, false>( + int8_t*, int8_t*, half*, half*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +// qk_quant_gran=2 (kPerWarp), CTA_Q=32, CTA_K=16, HEAD_DIM=128, bfloat16 +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 2, true, 0, hip_bfloat16, true, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); + +template void SpargeAttentionROCmF16Dispatched<32, 16, 16, 16, 128, 2, true, 0, hip_bfloat16, false, false>( + int8_t*, int8_t*, hip_bfloat16*, hip_bfloat16*, int32_t*, int32_t*, int32_t*, float*, float*, float*, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float); diff --git a/csrc/reduction_utils.h b/csrc/reduction_utils.h new file mode 100644 index 0000000..d63d19c --- /dev/null +++ b/csrc/reduction_utils.h @@ -0,0 +1,163 @@ +#pragma once +#include +#include +#include + +namespace vllm { + +__device__ __forceinline__ unsigned full_mask() { + return 0xFFFFFFFFu; +} + +template +__device__ __forceinline__ T shfl_xor(T v, int laneMask) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + return __shfl_xor(v, laneMask, warpSize); +#else + return __shfl_xor_sync(full_mask(), v, laneMask, warpSize); +#endif +} + +template +__device__ __forceinline__ T warpReduceSum(T val) { + + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + val += shfl_xor(val, offset); + } + return val; +} + +template +__device__ __forceinline__ void warpReduceSumV2(T* val) { +#pragma unroll + for (int i = 0; i < NUM; ++i) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + val[i] += shfl_xor(val[i], offset); + } + } +} + +template +__device__ __forceinline__ T blockReduceSum(T val) { + __shared__ T shared[64]; // 64 >= max(warpSize) for NV(32)/AMD(64) + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(0); + agg = warpReduceSum(agg); + return agg; +} + +template +__device__ __forceinline__ T blockAllReduceSum(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(0); + agg = warpReduceSum(agg); + return agg; +} + +template +__device__ __forceinline__ void blockReduceSumV2(T* val) { + + __shared__ T shared[NUM][65]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; ++i) shared[i][wid] = val[i]; + } + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; +#pragma unroll + for (int i = 0; i < NUM; ++i) { + T tmp = (lane < numWarps) ? shared[i][lane] : T(0); + val[i] = tmp; + } + warpReduceSumV2(val); +} + +template +__device__ __forceinline__ T warpReduceMax(T val) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + T other = shfl_xor(val, offset); + val = val > other ? val : other; + } + return val; +} + +template +__device__ __forceinline__ T blockReduceMax(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMax(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + + T agg = (lane < numWarps) ? shared[lane] : T(-1e20); + agg = warpReduceMax(agg); + return agg; +} + +template +__device__ __forceinline__ T blockAllReduceMax(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMax(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(-1e20); + agg = warpReduceMax(agg); + return agg; +} + +template +__device__ __forceinline__ T warpReduceMin(T val) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + T other = shfl_xor(val, offset); + val = val < other ? val : other; + } + return val; +} + +template +__device__ __forceinline__ T blockReduceMin(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMin(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(1e20); + agg = warpReduceMin(agg); + return agg; +} + +} // namespace vllm diff --git a/csrc/reduction_utils_hip.h b/csrc/reduction_utils_hip.h new file mode 100644 index 0000000..f44a164 --- /dev/null +++ b/csrc/reduction_utils_hip.h @@ -0,0 +1,164 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once +#include +#include +#include + +namespace vllm { + +__device__ __forceinline__ unsigned full_mask() { + return 0xFFFFFFFFu; +} + +template +__device__ __forceinline__ T shfl_xor(T v, int laneMask) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + return __shfl_xor(v, laneMask, warpSize); +#else + return __shfl_xor_sync(full_mask(), v, laneMask, warpSize); +#endif +} + +template +__device__ __forceinline__ T warpReduceSum(T val) { + + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + val += shfl_xor(val, offset); + } + return val; +} + +template +__device__ __forceinline__ void warpReduceSumV2(T* val) { +#pragma unroll + for (int i = 0; i < NUM; ++i) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + val[i] += shfl_xor(val[i], offset); + } + } +} + +template +__device__ __forceinline__ T blockReduceSum(T val) { + __shared__ T shared[64]; // 64 >= max(warpSize) for NV(32)/AMD(64) + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(0); + agg = warpReduceSum(agg); + return agg; +} + +template +__device__ __forceinline__ T blockAllReduceSum(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(0); + agg = warpReduceSum(agg); + return agg; +} + +template +__device__ __forceinline__ void blockReduceSumV2(T* val) { + + __shared__ T shared[NUM][65]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; ++i) shared[i][wid] = val[i]; + } + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; +#pragma unroll + for (int i = 0; i < NUM; ++i) { + T tmp = (lane < numWarps) ? shared[i][lane] : T(0); + val[i] = tmp; + } + warpReduceSumV2(val); +} + +template +__device__ __forceinline__ T warpReduceMax(T val) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + T other = shfl_xor(val, offset); + val = val > other ? val : other; + } + return val; +} + +template +__device__ __forceinline__ T blockReduceMax(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMax(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + + T agg = (lane < numWarps) ? shared[lane] : T(-1e20); + agg = warpReduceMax(agg); + return agg; +} + +template +__device__ __forceinline__ T blockAllReduceMax(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMax(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(-1e20); + agg = warpReduceMax(agg); + return agg; +} + +template +__device__ __forceinline__ T warpReduceMin(T val) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + T other = shfl_xor(val, offset); + val = val < other ? val : other; + } + return val; +} + +template +__device__ __forceinline__ T blockReduceMin(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMin(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(1e20); + agg = warpReduceMin(agg); + return agg; +} + +} // namespace vllm diff --git a/csrc/utils.cuh b/csrc/utils.cuh new file mode 100644 index 0000000..ca83d2c --- /dev/null +++ b/csrc/utils.cuh @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.is_cuda(), "Tensor " #x " must be on CUDA") +#define CHECK_DTYPE(x, true_dtype) \ + TORCH_CHECK(x.dtype() == true_dtype, \ + "Tensor " #x " must have dtype (" #true_dtype ")") +#define CHECK_DIMS(x, true_dim) \ + TORCH_CHECK(x.dim() == true_dim, \ + "Tensor " #x " must have dimension number (" #true_dim ")") +#define CHECK_NUMEL(x, minimum) \ + TORCH_CHECK(x.numel() >= minimum, \ + "Tensor " #x " must have at last " #minimum " elements") +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + "Tensor " #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), "Tensor " #x " must be contiguous") +#define CHECK_LASTDIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.stride(-1) == 1, \ + "Tensor " #x " must be contiguous at the last dimension") \ No newline at end of file diff --git a/setup.py b/setup.py index 3e6b13b..ceaa82d 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ """ import os +import sys from pathlib import Path import subprocess from packaging.version import parse, Version @@ -25,9 +26,49 @@ import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +THIRD_PARTY_DIR = os.path.join(THIS_DIR, "third_party") + HAS_SM90 = False SAGE2PP_ENABLED = True +# Check for ROCm +IS_ROCM = torch.version.hip is not None + +def clone_rocwmma(): + """Clone rocWMMA v2 from rocm-libraries repo to third_party directory.""" + rocm_libs_dir = os.path.join(THIRD_PARTY_DIR, "rocm-libraries") + rocwmma_include = None + + if os.path.exists(rocm_libs_dir): + print(f"rocm-libraries already exists at {rocm_libs_dir}") + rocwmma_include = os.path.join(rocm_libs_dir, "projects", "rocwmma", "library", "include") + else: + print("Cloning rocWMMA v2 from rocm-libraries...") + os.makedirs(THIRD_PARTY_DIR, exist_ok=True) + + # Use sparse checkout to only get rocwmma + clone_cmds = [ + f'git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git "{rocm_libs_dir}"', + f'cd "{rocm_libs_dir}" && git sparse-checkout set projects/rocwmma' + ] + + for cmd in clone_cmds: + ret = os.system(cmd) + if ret != 0: + print(f"Warning: Failed to execute: {cmd}") + return None + + rocwmma_include = os.path.join(rocm_libs_dir, "projects", "rocwmma", "library", "include") + + if rocwmma_include and os.path.exists(rocwmma_include): + print(f"rocWMMA v2 include path: {rocwmma_include}") + return rocwmma_include + else: + print("Warning: rocWMMA include path not found") + return None + + def run_instantiations(src_dir: str): base_path = Path(src_dir) py_files = [ @@ -39,6 +80,7 @@ def run_instantiations(src_dir: str): print(f"Running: {py_file}") os.system(f"python {py_file}") + def get_instantiations(src_dir: str): # get all .cu files under src_dir base_path = Path(src_dir) @@ -48,161 +90,244 @@ def get_instantiations(src_dir: str): if path.is_file() and path.suffix == ".cu" ] + # Supported NVIDIA GPU architectures. SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0"} -# Compiler flags. -CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] -NVCC_FLAGS = [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--use_fast_math", - "--threads=8", - "-Xptxas=-v", - "-diag-suppress=174", # suppress the specific warning - "-Xcompiler", "-include,cassert", # fix error occurs when compiling for SM90+ with newer CUDA toolkits -] - ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 -CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] - -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") - -def get_nvcc_cuda_version(cuda_dir: str) -> Version: - """Get the CUDA version from nvcc. - - Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py - """ - nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], - universal_newlines=True) - output = nvcc_output.split() - release_idx = output.index("release") + 1 - nvcc_cuda_version = parse(output[release_idx].split(",")[0]) - return nvcc_cuda_version - -def get_torch_arch_list() -> Set[str]: - # TORCH_CUDA_ARCH_LIST can have one or more architectures, - # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the - # compiler to additionally include PTX code that can be runtime-compiled - # and executed on the 8.6 or newer architectures. While the PTX code will - # not give the best performance on the newer architectures, it provides - # forward compatibility. - env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) - if env_arch_list is None: - return set() - - # List are separated by ; or space. - torch_arch_list = set(env_arch_list.replace(" ", ";").split(";")) - if not torch_arch_list: - return set() - - # Filter out the invalid architectures and print a warning. - valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS}) - arch_list = torch_arch_list.intersection(valid_archs) - # If none of the specified architectures are valid, raise an error. - if not arch_list: + +ext_modules = [] +cmdclass = {} + +if IS_ROCM: + print("Building for ROCm (AMD GPUs)") + + # Get ROCm architecture + def get_rocm_arch(): + try: + if torch.cuda.is_available(): + props = torch.cuda.get_device_properties(0) + return props.gcnArchName.split(':')[0] + except: + pass + rocm_arch = os.environ.get("ROCM_ARCH", None) + if rocm_arch: + return rocm_arch + return "gfx1100" # default + + rocm_arch = get_rocm_arch() + print(f"Detected ROCm architecture: {rocm_arch}") + + # Clone rocWMMA v2 + rocwmma_include = clone_rocwmma() + + # Compiler flags for ROCm + debug = os.environ.get("SA_DEBUG", "0") == "1" + base_flags = ["-std=c++17", f"-D_GLIBCXX_USE_CXX11_ABI={ABI}", "-DUSE_ROCM=1", + "-U__HIP_NO_HALF_CONVERSIONS__"] + debug_flags = ["-O0", "-g3", "-ggdb", "-fno-inline", "-fno-omit-frame-pointer"] if debug else ["-O3"] + + rocm_hipcc = base_flags + debug_flags + [f"--offload-arch={rocm_arch}"] + + # Add architecture-specific defines for our kernel code + if rocm_arch.startswith("gfx9"): + rocm_hipcc.append("-DSA_ARCH_MI_SERIES=1") + elif rocm_arch.startswith("gfx10") or rocm_arch.startswith("gfx11"): + rocm_hipcc.append("-DSA_ARCH_RDNA_SERIES=1") + + # Windows-specific: avoid GPU RDC which causes linker issues + if sys.platform == "win32": + rocm_hipcc.append("-fno-gpu-rdc") + + rocm_hipcc.append(f"-D__ROCM_ARCH_{rocm_arch.upper()}") + + rocm_cxx = base_flags + debug_flags + + include_dirs = [] + if rocwmma_include: + include_dirs.append(rocwmma_include) + print(f"Using rocWMMA v2 from: {rocwmma_include}") + else: + print("Warning: Using system rocWMMA (may be older v1.x version)") + + is_mi_series = rocm_arch.startswith("gfx9") + + ext_kwargs = { + "extra_compile_args": {"cxx": rocm_cxx, "nvcc": rocm_hipcc}, + } + if include_dirs: + ext_kwargs["include_dirs"] = include_dirs + + # Build qattn_rocm extension with FP16 kernels (all archs) + qattn_sources = [ + "csrc/qattn/rocm/pybind_rocm.cpp", + "csrc/qattn/rocm/sgattn_f16.cu", + "csrc/qattn/rocm/launch_sgattn_f16.cu", + ] + + if is_mi_series: + qattn_sources.extend([ + "csrc/qattn/rocm/launch_sgattn.cu", + "csrc/qattn/rocm/sgattn.cu" + ]) + print(f"Building _qattn with FP8+FP16 for MI-series GPU ({rocm_arch})") + else: + print(f"Building _qattn with FP16 only for RDNA GPU ({rocm_arch})") + + ext_modules.append( + CUDAExtension( + "spas_sage_attn._qattn", + sources=qattn_sources, + **ext_kwargs + ) + ) + + # Build fused extension + ext_modules.append( + CUDAExtension( + "spas_sage_attn._fused", + sources=["csrc/fused/rocm/pybind_rocm.cpp", "csrc/fused/rocm/fused.cu"], + **ext_kwargs + ) + ) + + cmdclass = {"build_ext": BuildExtension} + +else: + # NVIDIA CUDA build + # Compiler flags. + CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] + NVCC_FLAGS = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--use_fast_math", + "--threads=8", + "-Xptxas=-v", + "-diag-suppress=174", # suppress the specific warning + "-Xcompiler", "-include,cassert", # fix error occurs when compiling for SM90+ with newer CUDA toolkits + ] + + CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] + NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] + + if CUDA_HOME is None: raise RuntimeError( - "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env " - f"variable ({env_arch_list}) is supported. " - f"Supported CUDA architectures are: {valid_archs}.") - invalid_arch_list = torch_arch_list - valid_archs - if invalid_arch_list: - warnings.warn( - f"Unsupported CUDA architectures ({invalid_arch_list}) are " - "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " - f"({env_arch_list}). Supported CUDA architectures are: " - f"{valid_archs}.") - return arch_list - -# First, check the TORCH_CUDA_ARCH_LIST environment variable. -compute_capabilities = get_torch_arch_list() -if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 8: + "Cannot find CUDA_HOME. CUDA must be available to build the package.") + + def get_nvcc_cuda_version(cuda_dir: str) -> Version: + nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + def get_torch_arch_list() -> Set[str]: + env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + if env_arch_list is None: + return set() + + torch_arch_list = set(env_arch_list.replace(" ", ";").split(";")) + if not torch_arch_list: + return set() + + valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS}) + arch_list = torch_arch_list.intersection(valid_archs) + if not arch_list: raise RuntimeError( - "GPUs with compute capability below 8.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) -if not compute_capabilities: - raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.") - -# Validate the NVCC CUDA version. -if nvcc_cuda_version < Version("12.0"): - raise RuntimeError("CUDA 12.0 or higher is required to build the package.") -if nvcc_cuda_version < Version("12.4"): - if any(cc.startswith("8.9") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 12.4 or higher is required for compute capability 8.9.") - if any(cc.startswith("9.0") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 12.4 or higher is required for compute capability 9.0.") -if nvcc_cuda_version < Version("12.8"): - warnings.warn("CUDA 12.8 or higher is required for Sage2++") - SAGE2PP_ENABLED = False - -# Add target compute capabilities to NVCC flags. -for capability in compute_capabilities: - num = capability.replace(".", "") - if num == '90': - num = '90a' - HAS_SM90 = True - CXX_FLAGS += ["-DHAS_SM90"] - if num == '80' or num == '86' or num == '87': + "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env " + f"variable ({env_arch_list}) is supported. " + f"Supported CUDA architectures are: {valid_archs}.") + invalid_arch_list = torch_arch_list - valid_archs + if invalid_arch_list: + warnings.warn( + f"Unsupported CUDA architectures ({invalid_arch_list}) are " + "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " + f"({env_arch_list}). Supported CUDA architectures are: " + f"{valid_archs}.") + return arch_list + + compute_capabilities = get_torch_arch_list() + if not compute_capabilities: + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 8: + raise RuntimeError( + "GPUs with compute capability below 8.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + + nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) + if not compute_capabilities: + raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.") + + if nvcc_cuda_version < Version("12.0"): + raise RuntimeError("CUDA 12.0 or higher is required to build the package.") + if nvcc_cuda_version < Version("12.4"): + if any(cc.startswith("8.9") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 12.4 or higher is required for compute capability 8.9.") + if any(cc.startswith("9.0") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 12.4 or higher is required for compute capability 9.0.") + if nvcc_cuda_version < Version("12.8"): + warnings.warn("CUDA 12.8 or higher is required for Sage2++") SAGE2PP_ENABLED = False - - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] -if SAGE2PP_ENABLED: - CXX_FLAGS += ["-DSAGE2PP_ENABLED"] + for capability in compute_capabilities: + num = capability.replace(".", "") + if num == '90': + num = '90a' + HAS_SM90 = True + CXX_FLAGS += ["-DHAS_SM90"] + if num == '80' or num == '86' or num == '87': + SAGE2PP_ENABLED = False + + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] -ext_modules = [] + if SAGE2PP_ENABLED: + CXX_FLAGS += ["-DSAGE2PP_ENABLED"] -run_instantiations("csrc/qattn/instantiations_sm80") -run_instantiations("csrc/qattn/instantiations_sm89") -run_instantiations("csrc/qattn/instantiations_sm90") - -sources = [ - "csrc/qattn/pybind.cpp", - "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu", - "csrc/qattn/qk_int_sv_f8_cuda_sm89.cu", -] + get_instantiations("csrc/qattn/instantiations_sm80") + get_instantiations("csrc/qattn/instantiations_sm89") - -if HAS_SM90: - sources += ["csrc/qattn/qk_int_sv_f8_cuda_sm90.cu", ] - sources += get_instantiations("csrc/qattn/instantiations_sm90") - -qattn_extension = CUDAExtension( - name="spas_sage_attn._qattn", - sources=sources, - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, - extra_link_args=['-lcuda'], -) -ext_modules.append(qattn_extension) - -fused_extension = CUDAExtension( - name="spas_sage_attn._fused", - sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(fused_extension) + run_instantiations("csrc/qattn/instantiations_sm80") + run_instantiations("csrc/qattn/instantiations_sm89") + run_instantiations("csrc/qattn/instantiations_sm90") + + sources = [ + "csrc/qattn/pybind.cpp", + "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu", + "csrc/qattn/qk_int_sv_f8_cuda_sm89.cu", + ] + get_instantiations("csrc/qattn/instantiations_sm80") + get_instantiations("csrc/qattn/instantiations_sm89") + + if HAS_SM90: + sources += ["csrc/qattn/qk_int_sv_f8_cuda_sm90.cu", ] + sources += get_instantiations("csrc/qattn/instantiations_sm90") + + qattn_extension = CUDAExtension( + name="spas_sage_attn._qattn", + sources=sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + extra_link_args=['-lcuda'], + ) + ext_modules.append(qattn_extension) + + fused_extension = CUDAExtension( + name="spas_sage_attn._fused", + sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) + ext_modules.append(fused_extension) + + cmdclass = {"build_ext": BuildExtension} setup( name='spas_sage_attn', @@ -228,5 +353,5 @@ def get_torch_arch_list() -> Set[str]: 'Operating System :: OS Independent', ], ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension}, + cmdclass=cmdclass, ) diff --git a/spas_sage_attn/core.py b/spas_sage_attn/core.py index 01eda7a..b38712e 100644 --- a/spas_sage_attn/core.py +++ b/spas_sage_attn/core.py @@ -29,19 +29,74 @@ print("Warning: Sage2++ NOT enabled") SAGE2PP_ENABLED = False +# Detect ROCm +IS_ROCM = hasattr(torch.version, 'hip') and torch.version.hip is not None + +def get_gpu_arch_info(): + """ + Get GPU architecture info that works on both CUDA and ROCm. + Returns: (arch_string, is_rocm, is_rdna, is_mi_series) + """ + if not torch.cuda.is_available(): + return None, False, False, False + + props = torch.cuda.get_device_properties(0) + if IS_ROCM: + # ROCm: use gcnArchName + arch = props.gcnArchName.split(':')[0] if hasattr(props, 'gcnArchName') else "gfx1100" + is_rdna = arch.startswith("gfx10") or arch.startswith("gfx11") + is_mi_series = arch.startswith("gfx9") + return arch, True, is_rdna, is_mi_series + else: + # CUDA: compute capability + major, minor = torch.cuda.get_device_capability(0) + return f"sm{major}{minor}", False, False, False + def get_cuda_arch_versions(): - cuda_archs = [] + """ + Get architecture strings for all GPUs. + On CUDA: returns ["sm80", "sm90", etc.] + On ROCm: returns ["gfx1100", "gfx1151", etc.] + """ + archs = [] for i in range(torch.cuda.device_count()): - major, minor = torch.cuda.get_device_capability(i) - cuda_archs.append(f"sm{major}{minor}") - return cuda_archs + if IS_ROCM: + props = torch.cuda.get_device_properties(i) + arch = props.gcnArchName.split(':')[0] if hasattr(props, 'gcnArchName') else "gfx1100" + archs.append(arch) + else: + major, minor = torch.cuda.get_device_capability(i) + archs.append(f"sm{major}{minor}") + return archs + +def is_sm90_or_equivalent(arch): + """Check if architecture supports SM90-level features (wgmma, etc.)""" + if IS_ROCM: + # On ROCm, MI series (gfx9xx) has similar capabilities + return arch.startswith("gfx9") + else: + return arch == "sm90" + +def supports_fp8(arch): + """Check if architecture supports FP8.""" + if IS_ROCM: + # Only MI series supports FP8 on ROCm + return arch.startswith("gfx9") + else: + # CUDA: sm89+ supports FP8 + try: + sm_num = int(arch[2:]) + return sm_num >= 89 + except: + return False @torch.compiler.disable def spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, smooth_k=True, simthreshd1=0.6, cdfthreshd=0.98, pvthreshd=50, attention_sink=False, tensor_layout="HND", output_dtype=torch.float16, return_sparsity=False): assert tensor_layout in ['HND', 'NHD'] if tensor_layout == 'NHD': q, k, v = map(lambda t: rearrange(t, '... L H D -> ... H L D'), (q, k, v)) - assert q.size(-2)>=128, "seq_len should be not less than 128." + min_seq = 64 if IS_ROCM else 128 + assert q.size(-2)>=min_seq, f"seq_len should be not less than {min_seq}." torch.cuda.set_device(v.device) dtype = q.dtype @@ -56,8 +111,20 @@ def spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_caus headdim = q.size(-1) arch = get_cuda_arch_versions()[q.device.index] - if arch == "sm90": + + # Check if this architecture supports FP8 (required for sage2) + if not supports_fp8(arch): + raise RuntimeError( + f"spas_sage2_attn_meansim_cuda requires FP8 support, but {arch} does not support FP8. " + f"On RDNA GPUs (gfx10xx/gfx11xx), use spas_sage_attn_meansim_cuda instead." + ) + + # Choose block sizes based on architecture + # For ROCm, use CTA_Q=64, CTA_K=64 to fit shared memory constraints + if is_sm90_or_equivalent(arch): lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, return_lut=True, attention_sink=attention_sink, BLKQ=64, BLKK=128) + elif IS_ROCM: + lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, return_lut=True, attention_sink=attention_sink, BLKQ=64, BLKK=64) else: lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, return_lut=True, attention_sink=attention_sink, BLKQ=128, BLKK=64) @@ -69,28 +136,28 @@ def spas_sage2_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_caus pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device) o = torch.empty_like(q) + # sm80/sm86/sm87 don't support FP8, use FP16 kernel directly if arch in ("sm80", "sm86", "sm87"): qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0 ) else: - ## quant v + # sm89+ supports FP8, quantize V b, h_kv, kv_len, head_dim = v.shape padded_len = (kv_len + 127) // 128 * 128 v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device) fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1) v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device) v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) - #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1) fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1) - + if arch == "sm90": qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) elif SAGE2PP_ENABLED: qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) else: qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) - + if tensor_layout == 'NHD': o = rearrange(o, '... H L D -> ... L H D') if return_sparsity: @@ -107,7 +174,8 @@ def spas_sage2_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is assert tensor_layout in ['HND', 'NHD'] if tensor_layout == 'NHD': q, k, v = map(lambda t: rearrange(t, '... L H D -> ... H L D'), (q, k, v)) - assert q.size(-2)>=128, "seq_len should be not less than 128." + min_seq = 64 if IS_ROCM else 128 + assert q.size(-2)>=min_seq, f"seq_len should be not less than {min_seq}." torch.cuda.set_device(v.device) dtype = q.dtype @@ -122,8 +190,18 @@ def spas_sage2_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is headdim = q.size(-1) arch = get_cuda_arch_versions()[q.device.index] - if arch == "sm90": + + # Check if this architecture supports FP8 (required for sage2) + if not supports_fp8(arch): + raise RuntimeError( + f"spas_sage2_attn_meansim_topk_cuda requires FP8 support, but {arch} does not support FP8. " + f"On RDNA GPUs (gfx10xx/gfx11xx), use spas_sage_attn_meansim_topk_cuda instead." + ) + + if is_sm90_or_equivalent(arch): lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, topk=topk, return_lut=True, attention_sink=attention_sink, BLKQ=64, BLKK=128) + elif IS_ROCM: + lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, topk=topk, return_lut=True, attention_sink=attention_sink, BLKQ=64, BLKK=64) else: lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, topk=topk, return_lut=True, attention_sink=attention_sink, BLKQ=128, BLKK=64) @@ -135,28 +213,28 @@ def spas_sage2_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device) o = torch.empty_like(q) + # sm80/sm86/sm87 don't support FP8, use FP16 kernel directly if arch in ("sm80", "sm86", "sm87"): qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0 ) else: - ## quant v + # sm89+ supports FP8, quantize V b, h_kv, kv_len, head_dim = v.shape padded_len = (kv_len + 127) // 128 * 128 v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device) fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1) v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device) v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) - #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1) fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1) - + if arch == "sm90": qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) elif SAGE2PP_ENABLED: qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) else: qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) - + if tensor_layout == 'NHD': o = rearrange(o, '... H L D -> ... L H D') if return_sparsity: @@ -167,13 +245,14 @@ def spas_sage2_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is return o, qk_sparsity.item() else: return o - + @torch.compiler.disable def block_sparse_sage2_attn_cuda(q, k, v, mask_id=None, dropout_p=0.0, scale=None, smooth_k=True, pvthreshd=50, attention_sink=False, tensor_layout="HND", output_dtype=torch.float16, return_sparsity=False): assert tensor_layout in ['HND', 'NHD'] if tensor_layout == 'NHD': q, k, v = map(lambda t: rearrange(t, '... L H D -> ... H L D'), (q, k, v)) - assert q.size(-2)>=128, "seq_len should be not less than 128." + min_seq = 64 if IS_ROCM else 128 + assert q.size(-2)>=min_seq, f"seq_len should be not less than {min_seq}." torch.cuda.set_device(v.device) dtype = q.dtype @@ -189,8 +268,17 @@ def block_sparse_sage2_attn_cuda(q, k, v, mask_id=None, dropout_p=0.0, scale=Non arch = get_cuda_arch_versions()[q.device.index] - if arch == "sm90": + # Check if this architecture supports FP8 (required for sage2) + if not supports_fp8(arch): + raise RuntimeError( + f"block_sparse_sage2_attn_cuda requires FP8 support, but {arch} does not support FP8. " + f"On RDNA GPUs (gfx10xx/gfx11xx), use spas_sage_attn_meansim_cuda instead." + ) + + if is_sm90_or_equivalent(arch): q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 64, 128) + elif IS_ROCM: + q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 64, 64) else: q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 128, 64) lut, valid_block_num = block_map_lut_triton(block_map=mask_id) @@ -202,28 +290,26 @@ def block_sparse_sage2_attn_cuda(q, k, v, mask_id=None, dropout_p=0.0, scale=Non pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device) o = torch.empty_like(q) + # sm80/sm86/sm87 don't support FP8, use FP16 kernel directly if arch in ("sm80", "sm86", "sm87"): qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, False, 1, scale, 0 ) else: - ## quant v + # sm89+ supports FP8, quantize V b, h_kv, kv_len, head_dim = v.shape padded_len = (kv_len + 127) // 128 * 128 v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device) fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1) v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device) v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) - #fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1) - fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1) - + fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 448.0, 1) + if arch == "sm90": qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold_sm90(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) - elif SAGE2PP_ENABLED: - qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) else: qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(q_int8, k_int8, v_fp8, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, v_scale, 1, False, 1, scale, 0) - + if tensor_layout == 'NHD': o = rearrange(o, '... H L D -> ... L H D') if return_sparsity: @@ -237,7 +323,14 @@ def spas_sage_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_causa assert tensor_layout in ['HND', 'NHD'] if tensor_layout == 'NHD': q, k, v = map(lambda t: rearrange(t, '... L H D -> ... H L D'), (q, k, v)) - assert q.size(-2)>=128, "seq_len should be not less than 128." + + headdim = q.size(-1) + # min_seq depends on CTA_Q which depends on headdim for ROCm + if IS_ROCM: + min_seq = 32 if headdim == 128 else 64 + else: + min_seq = 128 + assert q.size(-2)>=min_seq, f"seq_len should be not less than {min_seq}." torch.cuda.set_device(v.device) dtype = q.dtype @@ -246,12 +339,20 @@ def spas_sage_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_causa else: q, k, v = q.contiguous().to(torch.bfloat16), k.contiguous().to(torch.bfloat16), v.contiguous().to(torch.float16) - if smooth_k: - km = k.mean(dim=-2, keepdim=True) - # k = k - km - headdim = q.size(-1) - - lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, return_lut=True, attention_sink=attention_sink) # + # Always compute km for block map generation + km = k.mean(dim=-2, keepdim=True) + # if smooth_k: + # k = k - km + + # For ROCm, use smaller tiles for headdim=128 to reduce register pressure: + # - headdim=64: CTA_Q=64, CTA_K=64 + # - headdim=128: CTA_Q=32, CTA_K=16 (best performance at 10% sparsity) + if IS_ROCM: + blkq = 32 if headdim == 128 else 64 + blkk = 16 if headdim == 128 else 64 + lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, return_lut=True, attention_sink=attention_sink, BLKQ=blkq, BLKK=blkk) + else: + lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, return_lut=True, attention_sink=attention_sink) if scale is None: scale = 1.0 / (headdim ** 0.5) @@ -261,8 +362,14 @@ def spas_sage_attn_meansim_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_causa pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device) _is_causal = 1 if is_causal else 0 - o = torch.empty_like(q) + o = torch.empty(q.shape, dtype=v.dtype, device=q.device) qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, _is_causal, 1, scale, 0) + # Sync to ensure kernel completes before subsequent operations (helps with ROCm resource management) + if IS_ROCM: + torch.cuda.synchronize() + # Convert output back to original dtype if needed + if o.dtype != dtype: + o = o.to(dtype) if tensor_layout == 'NHD': o = rearrange(o, '... H L D -> ... L H D') @@ -280,7 +387,14 @@ def spas_sage_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_ assert tensor_layout in ['HND', 'NHD'] if tensor_layout == 'NHD': q, k, v = map(lambda t: rearrange(t, '... L H D -> ... H L D'), (q, k, v)) - assert q.size(-2)>=128, "seq_len should be not less than 128." + + headdim = q.size(-1) + # min_seq depends on CTA_Q which depends on headdim for ROCm + if IS_ROCM: + min_seq = 32 if headdim == 128 else 64 + else: + min_seq = 128 + assert q.size(-2)>=min_seq, f"seq_len should be not less than {min_seq}." torch.cuda.set_device(v.device) dtype = q.dtype @@ -294,7 +408,15 @@ def spas_sage_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_ # k = k - km headdim = q.size(-1) - lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, topk=topk, return_lut=True, attention_sink=attention_sink) # + # For ROCm, use smaller tiles for headdim=128 to reduce register pressure: + # - headdim=64: CTA_Q=64, CTA_K=64 + # - headdim=128: CTA_Q=32, CTA_K=16 (best performance at 10% sparsity) + if IS_ROCM: + blkq = 32 if headdim == 128 else 64 + blkk = 16 if headdim == 128 else 64 + lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, topk=topk, return_lut=True, attention_sink=attention_sink, BLKQ=blkq, BLKK=blkk) + else: + lut, valid_block_num, q_int8, q_scale, k_int8, k_scale = get_block_map_meansim_fuse_quant(q, k, km, is_causal=is_causal, simthreshd1=simthreshd1, cdfthreshd=cdfthreshd, topk=topk, return_lut=True, attention_sink=attention_sink) if scale is None: scale = 1.0 / (headdim ** 0.5) @@ -304,8 +426,14 @@ def spas_sage_attn_meansim_topk_cuda(q, k, v, attn_mask=None, dropout_p=0.0, is_ pvthreshd = hyperparameter_check(pvthreshd, q.size(-3), q.device) _is_causal = 1 if is_causal else 0 - o = torch.empty_like(q) + o = torch.empty(q.shape, dtype=v.dtype, device=q.device) qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(q_int8, k_int8, v, o, lut, valid_block_num, pvthreshd, q_scale, k_scale, 1, _is_causal, 1, scale, 0) + # Sync to ensure kernel completes before subsequent operations (helps with ROCm resource management) + if IS_ROCM: + torch.cuda.synchronize() + # Convert output back to original dtype if needed + if o.dtype != dtype: + o = o.to(dtype) if tensor_layout == 'NHD': o = rearrange(o, '... H L D -> ... L H D')