87 changes: 79 additions & 8 deletions .gitignore
@@ -1,9 +1,80 @@
__pycache__
spas_sage_attn.egg-info
*.pkl
/dist
/build
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so
.DS_Store
inst*.cu
/unit_test

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
*.manifest
*.spec

# pip
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE
.idea/
.vscode/
.cursor/
*.swp
*.swo

# ROCm cloned libraries
/third_party/

# HIP generated files
*.hip

# Build artifacts
*.o
*.obj

# Instantiation generated files
csrc/qattn/instantiations_sm80/*.cu
csrc/qattn/instantiations_sm89/*.cu
csrc/qattn/instantiations_sm90/*.cu
185 changes: 185 additions & 0 deletions README_AMD_WINDOWS.md
@@ -0,0 +1,185 @@
# SpargeAttn - AMD ROCm on Windows Setup Guide

This guide explains how to build and run SpargeAttn on Windows with AMD GPUs using ROCm.

> **Note:** These steps should also work on Linux with minor modifications (use bash commands instead of PowerShell, `source venv/bin/activate` instead of `.\venv\Scripts\Activate.ps1`, and skip the Visual Studio environment setup). However, Linux support has not been tested yet and may have issues.

## Supported Hardware

SpargeAttn on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1103, gfx1151).

## Prerequisites

- Windows 10/11
- Python 3.11, 3.12, or 3.13
- Visual Studio 2022 with C++ build tools
- AMD Adrenalin driver (latest recommended)

## Installation

### 1. Install ROCm and PyTorch from TheRock

Follow the instructions at [ROCm/TheRock RELEASES.md](https://github.com/ROCm/TheRock/blob/main/RELEASES.md) to install ROCm and PyTorch wheels for your GPU architecture.

#### Create a Virtual Environment

```powershell
python -m venv venv
.\venv\Scripts\Activate.ps1
```

#### Install PyTorch (includes ROCm SDK as dependency)

For **gfx1151** (AMD Strix Halo iGPU):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ --pre torch torchaudio torchvision
```

For **gfx110X** (RX 7900 XTX, RX 7800 XT, RX 7700S, Radeon 780M):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/ --pre torch torchaudio torchvision
```

For **gfx120X** (RX 9060, RX 9070):
```powershell
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre torch torchaudio torchvision
```

#### Initialize ROCm SDK

```powershell
rocm-sdk init
```

#### Install Triton with AMD Windows Support

```powershell
pip install triton-windows
```
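
Before moving on, it can help to confirm that the installed PyTorch wheel is a ROCm build and that it sees the GPU. The following is a minimal sketch, not part of SpargeAttn: `describe_torch_build` is a hypothetical helper name, and it relies only on `torch.__version__`, `torch.version.hip`, and `torch.cuda.is_available()`.

```python
import importlib.util


def describe_torch_build():
    """Return a small dict describing the installed PyTorch build, if any."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    return {
        "installed": True,
        "version": torch.__version__,
        # torch.version.hip is a version string on ROCm builds and None otherwise.
        "hip": getattr(torch.version, "hip", None),
        # ROCm builds expose the GPU through the torch.cuda API.
        "gpu_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_torch_build().items():
        print(f"{key}: {value}")
```

If `hip` prints `None` or `gpu_available` is `False`, the wheel index or driver installation is the first place to look.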

### 2. Set Environment Variables

Open a PowerShell terminal and run:

```powershell
# Activate Visual Studio environment
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }

# Activate the virtual environment
.\venv\Scripts\Activate.ps1

# Set ROCm paths using rocm-sdk
$ROCM_ROOT = (rocm-sdk path --root).Trim()
$ROCM_BIN = (rocm-sdk path --bin).Trim()
$env:ROCM_HOME = $ROCM_ROOT
$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH"

# Set compiler and build settings
$env:CC = "clang-cl"
$env:CXX = "clang-cl"
$env:DISTUTILS_USE_SDK = "1"

# Enable experimental features
$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE"
$env:TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL = "1"
```
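
Because these variables only live in the current PowerShell process, a build started from a different terminal will not see them. As a rough sketch (the helper name `check_build_env` is an assumption, not part of this repository), the values set above can be checked from Python before running `pip install`:

```python
import os

# Variables set in the PowerShell snippet above; None means "any non-empty value".
REQUIRED_VARS = {
    "ROCM_HOME": None,
    "CC": "clang-cl",
    "CXX": "clang-cl",
    "DISTUTILS_USE_SDK": "1",
    "FLASH_ATTENTION_TRITON_AMD_ENABLE": "TRUE",
    "TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL": "1",
}


def check_build_env(env=None):
    """Return a list of problems; an empty list means every check passed."""
    env = os.environ if env is None else env
    problems = []
    for name, expected in REQUIRED_VARS.items():
        actual = env.get(name)
        if not actual:
            problems.append(f"{name} is not set")
        elif expected is not None and actual != expected:
            problems.append(f"{name} is {actual!r}, expected {expected!r}")
    return problems


if __name__ == "__main__":
    for problem in check_build_env():
        print(problem)
```

This does not verify that `clang-cl` or the ROCm binaries are actually reachable on `PATH`; it only checks the variables listed in this section.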

### 3. Build and Install SpargeAttn

```powershell
cd <path_to_spargeattn>
pip install --no-build-isolation -v .
```

## Testing

### Quick Correctness Test

Run this script to verify SpargeAttn is working correctly by comparing against PyTorch SDPA:

```python
import torch
import torch.nn.functional as F
from spas_sage_attn.core import spas_sage_attn_meansim_cuda

device = torch.device('cuda')

# Create random test tensors (use float16 for ROCm compatibility)
q = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device)
k = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device)
v = torch.randn(1, 12, 2048, 128, dtype=torch.float16, device=device)

# Compute reference output using PyTorch SDPA
with torch.no_grad():
    sdpa = F.scaled_dot_product_attention(q.float(), k.float(), v.float()).to(torch.float16)

# Compute SpargeAttn output (cdfthreshd=1.0, i.e. dense attention with no blocks skipped)
sparge = spas_sage_attn_meansim_cuda(
    q, k, v,
    is_causal=False,
    smooth_k=False,
    simthreshd1=0.0,  # No similarity threshold (keep all blocks)
    cdfthreshd=1.0,   # 100%: no blocks skipped
    pvthreshd=0,
    tensor_layout='HND'
)

# Compare outputs using cosine similarity
cos = F.cosine_similarity(
    sdpa.flatten().float().unsqueeze(0),
    sparge.flatten().float().unsqueeze(0)
)
print(f'Cosine similarity: {cos.item():.6f}') # Should be ~0.9999
```

Save this as `test_spargeattn.py` and run:

```powershell
python test_spargeattn.py
```

Expected output (the inputs are random and unseeded, so the exact digits vary between runs):
```
Cosine similarity: 0.999900
```

A cosine similarity above 0.999 indicates the kernel is working correctly.
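
To make the metric concrete, here is a small pure-Python sketch of cosine similarity on toy vectors (the vectors are made up for illustration). It also shows one limitation worth keeping in mind: cosine similarity ignores overall scale, so a result that is off by a constant factor would still score close to 1.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


reference = [0.5, -1.0, 2.0, 0.25]
close = [0.5, -1.0, 2.0, 0.26]       # small elementwise error
scaled = [2 * x for x in reference]  # same direction, wrong magnitude

print(cosine_similarity(reference, close) > 0.999)
print(cosine_similarity(reference, scaled))
```

If scale errors are a concern, an additional check such as the maximum absolute difference between the two outputs can complement the cosine test.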

## Performance Notes

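At L=4096, D=128, bf16 vs PyTorch SDPA (with aotriton). Judging by the timings, "sparsity" in this table is the fraction of attention blocks that are actually computed, so 100% corresponds to dense attention: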

| Sparsity | Time | Speedup vs SDPA |
|----------|------|-----------------|
| 100% | 33.0 ms | 0.18x |
| 50% | 13.7 ms | 0.43x |
| 25% | 7.4 ms | 0.79x |
| **10%** | **3.2 ms** | **1.81x** |
| 5% | 1.8 ms | 3.26x |
| 2% | 1.0 ms | 6.07x |

**Break-even point**: ~20-25% sparsity. Below that, SpargeAttn is faster than dense SDPA.
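
The break-even estimate can be cross-checked from the table itself. The sketch below assumes speedup is defined as SDPA time divided by SpargeAttn time, recovers the implied SDPA baseline from each row, and linearly interpolates where the SpargeAttn time crosses that baseline. The interpolation is an illustrative assumption, not a measurement.

```python
# (percentage from the table, SpargeAttn time in ms, reported speedup vs SDPA)
ROWS = [
    (100, 33.0, 0.18),
    (50, 13.7, 0.43),
    (25, 7.4, 0.79),
    (10, 3.2, 1.81),
    (5, 1.8, 3.26),
    (2, 1.0, 6.07),
]

# speedup = sdpa_ms / sparge_ms, so each row implies an SDPA time.
implied_sdpa_ms = [ms * speedup for _, ms, speedup in ROWS]
sdpa_ms = sum(implied_sdpa_ms) / len(implied_sdpa_ms)


def break_even_percent(rows, baseline_ms):
    """Linearly interpolate the percentage at which SpargeAttn time equals baseline_ms."""
    ordered = sorted(rows)  # ascending percentage
    for (p_lo, t_lo, _), (p_hi, t_hi, _) in zip(ordered, ordered[1:]):
        if t_lo <= baseline_ms <= t_hi:
            return p_lo + (baseline_ms - t_lo) * (p_hi - p_lo) / (t_hi - t_lo)
    return None


if __name__ == "__main__":
    print(f"Implied SDPA time: {sdpa_ms:.2f} ms")
    print(f"Interpolated break-even: {break_even_percent(ROWS, sdpa_ms):.1f}%")
```

The implied SDPA time comes out just under 6 ms, and the interpolated crossover lands near the lower end of the 20-25% range quoted above.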

## Known Issues

1. **No FP8 support on RDNA3** - rocWMMA on gfx11xx doesn't support FP8, so FP16/BF16 is used for V.

2. **Triton compiler warnings** - You may see `clang-cl: warning: unknown argument ignored` warnings during the first run. These are harmless.

## Troubleshooting

### "LoadLibrary failed" or "cannot find amdhip64.dll"

Make sure you ran `rocm-sdk init` after installing the ROCm SDK packages.

### "LINK : fatal error LNK1104: cannot open file 'python312.lib'"

Ensure the Visual Studio environment is activated before building:
```powershell
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }
```

### "PermissionError" when compiling Triton kernels

This is a known Windows issue with temp file handling. Make sure you're using the latest `triton-windows` package (`pip install --upgrade triton-windows`).

112 changes: 112 additions & 0 deletions csrc/fused/rocm/dispatch_utils.h
@@ -0,0 +1,112 @@
/*
* Copyright (c) 2024 by SageAttention team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <torch/extension.h>
#include <cstdint>
#include <sstream>
#include <stdexcept>

#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
  if (head_dim == 64) { \
    constexpr int HEAD_DIM = 64; \
    __VA_ARGS__ \
  } else if (head_dim == 128) { \
    constexpr int HEAD_DIM = 128; \
    __VA_ARGS__ \
  } else { \
    std::ostringstream err_msg; \
    err_msg << "Unsupported head dim: " << int(head_dim); \
    throw std::invalid_argument(err_msg.str()); \
  }

#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \
  if (is_causal == 1) { \
    constexpr bool IS_CAUSAL = true; \
    __VA_ARGS__ \
  } else if (is_causal == 0) { \
    constexpr bool IS_CAUSAL = false; \
    __VA_ARGS__ \
  } else { \
    std::ostringstream err_msg; \
    err_msg << "Unsupported causal mode: " << int(is_causal); \
    throw std::invalid_argument(err_msg.str()); \
  }

#define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...) \
  if (qk_quant_gran == 2) { \
    constexpr int QK_QUANT_GRAN = 2; \
    __VA_ARGS__ \
  } else if (qk_quant_gran == 3) { \
    constexpr int QK_QUANT_GRAN = 3; \
    __VA_ARGS__ \
  } else { \
    std::ostringstream err_msg; \
    err_msg << "Unsupported qk_quant_gran: " << int(qk_quant_gran); \
    throw std::invalid_argument(err_msg.str()); \
  }

#define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...) \
  if (return_lse == 1) { \
    constexpr bool RETURN_LSE = true; \
    __VA_ARGS__ \
  } else if (return_lse == 0) { \
    constexpr bool RETURN_LSE = false; \
    __VA_ARGS__ \
  } else { \
    std::ostringstream err_msg; \
    err_msg << "Unsupported return_lse mode: " << int(return_lse); \
    throw std::invalid_argument(err_msg.str()); \
  }

#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
  if (pytorch_dtype == at::ScalarType::Half) { \
    using c_type = half; \
    __VA_ARGS__ \
  } else if (pytorch_dtype == at::ScalarType::BFloat16) { \
    using c_type = hip_bfloat16; \
    __VA_ARGS__ \
  } else { \
    std::ostringstream oss; \
    oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
    TORCH_CHECK(false, oss.str()); \
  }

#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \
  if (block_size == 64) { \
    constexpr int BLOCK_SIZE = 64; \
    __VA_ARGS__ \
  } else if (block_size == 128) { \
    constexpr int BLOCK_SIZE = 128; \
    __VA_ARGS__ \
  } else { \
    std::ostringstream err_msg; \
    err_msg << "Unsupported block_size " << int(block_size); \
    throw std::invalid_argument(err_msg.str()); \
  }

#define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...) \
  if (warp_block_size == 16) { \
    constexpr int WARP_BLOCK_SIZE = 16; \
    __VA_ARGS__ \
  } else if (warp_block_size == 32) { \
    constexpr int WARP_BLOCK_SIZE = 32; \
    __VA_ARGS__ \
  } else { \
    std::ostringstream err_msg; \
    err_msg << "Unsupported warp_block_size " << int(warp_block_size); \
    throw std::invalid_argument(err_msg.str()); \
  }
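
Each macro above turns a runtime value into a compile-time constant and throws for anything outside a fixed set. For callers on the Python side, the same accepted values can be checked up front to get an earlier and more descriptive error. This is only a sketch: `SUPPORTED` and `validate_dispatch_args` are hypothetical names that do not exist in the repository, and the value sets are copied from the macros in this header.

```python
# Values accepted by the DISPATCH_* macros in dispatch_utils.h.
SUPPORTED = {
    "head_dim": (64, 128),
    "is_causal": (0, 1),
    "qk_quant_gran": (2, 3),
    "return_lse": (0, 1),
    "block_size": (64, 128),
    "warp_block_size": (16, 32),
}


def validate_dispatch_args(**kwargs):
    """Raise ValueError for any argument the C++ dispatch macros would reject."""
    for name, value in kwargs.items():
        if name not in SUPPORTED:
            raise KeyError(f"unknown dispatch argument: {name}")
        if int(value) not in SUPPORTED[name]:
            raise ValueError(
                f"Unsupported {name}: {int(value)} (expected one of {SUPPORTED[name]})"
            )


# Booleans are accepted because the macros compare against 0 and 1.
validate_dispatch_args(head_dim=128, is_causal=True, qk_quant_gran=3)
```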