diff --git a/csrc/mma.cuh b/csrc/mma.cuh index d9f6551..ef92245 100644 --- a/csrc/mma.cuh +++ b/csrc/mma.cuh @@ -22,6 +22,7 @@ #include #include #include +#include namespace mma{ @@ -49,8 +50,7 @@ namespace mma{ #if defined(__CUDA_ARCH__) #define RUNTIME_ASSERT(x) __brkpt() #else -#include -#define RUNTIME_ASSERT(x) assert(0 && x) +#define RUNTIME_ASSERT(x) printf("%s\n",x);exit(-1) #endif enum class MMAMode { diff --git a/csrc/numeric_conversion.cuh b/csrc/numeric_conversion.cuh index e09a3d0..6f865a9 100644 --- a/csrc/numeric_conversion.cuh +++ b/csrc/numeric_conversion.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) @@ -32,8 +33,7 @@ #if defined(__CUDA_ARCH__) #define RUNTIME_ASSERT(x) __brkpt() #else -#include -#define RUNTIME_ASSERT(x) assert(0 && x) +#define RUNTIME_ASSERT(x) printf("%s\n",x);exit(-1) #endif __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1) @@ -139,4 +139,4 @@ __device__ __forceinline__ int8_t float_to_int8_rn(float x) uint32_t dst; asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); return reinterpret_cast(dst); -} \ No newline at end of file +} diff --git a/setup.py b/setup.py index 5e8e4e4..2690706 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,15 @@ import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +HAS_SM80 = False +HAS_SM86 = False +HAS_SM87 = False +HAS_SM89 = False HAS_SM90 = False +HAS_SM100 = False +HAS_SM101 = False +HAS_SM110 = False +HAS_SM120 = False def run_instantiations(src_dir: str): base_path = Path(src_dir) @@ -48,8 +56,8 @@ def get_instantiations(src_dir: str): ] # Supported NVIDIA GPU architectures. +# SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "11.0", "12.0"} SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0"} - # Compiler flags. CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] NVCC_FLAGS = [ @@ -147,11 +155,42 @@ def get_torch_arch_list() -> Set[str]: # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: - num = capability[0] + capability[2] - if num == '90': - num = '90a' + if capability.startswith("8.0"): + HAS_SM80 = True + num = "80" + CXX_FLAGS += ["-DHAS_SM80"] + elif capability.startswith("8.6"): + HAS_SM86 = True + num = "86" + CXX_FLAGS += ["-DHAS_SM86"] + elif capability.startswith("8.7"): + HAS_SM87 = True + num = "87" + CXX_FLAGS += ["-DHAS_SM87"] + elif capability.startswith("8.9"): + HAS_SM89 = True + num = "89" + CXX_FLAGS += ["-DHAS_SM89"] + elif capability.startswith("9.0"): HAS_SM90 = True + num = "90a" # need to use sm90a instead of sm90 to use wgmma ptx instruction. CXX_FLAGS += ["-DHAS_SM90"] + elif capability.startswith("10.0"): + HAS_SM100 = True + num = "100" + CXX_FLAGS += ["-DHAS_SM100"] + elif capability.startswith("10.1"): + HAS_SM101 = True + num = "101" + CXX_FLAGS += ["-DHAS_SM101"] + elif capability.startswith("11.0"): + HAS_SM110 = True + num = "110" + CXX_FLAGS += ["-DHAS_SM110"] + elif capability.startswith("12.0"): + HAS_SM120 = True + num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions. + CXX_FLAGS += ["-DHAS_SM120"] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]