From 774ae1054414aba9a7464e3b9768cd9f69eb9a66 Mon Sep 17 00:00:00 2001 From: Johnny Date: Tue, 22 Apr 2025 14:09:59 +0200 Subject: [PATCH 1/4] add more NVIDIA DEVICES Support --- setup.py | 47 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 5e8e4e4..e5a8f9f 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,15 @@ import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +HAS_SM80 = False +HAS_SM86 = False +HAS_SM87 = False +HAS_SM89 = False HAS_SM90 = False +HAS_SM100 = False +HAS_SM101 = False +HAS_SM110 = False +HAS_SM120 = False def run_instantiations(src_dir: str): base_path = Path(src_dir) @@ -48,7 +56,7 @@ def get_instantiations(src_dir: str): ] # Supported NVIDIA GPU architectures. -SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0"} +SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "11.0", "12.0"} # Compiler flags. CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] @@ -147,11 +155,42 @@ def get_torch_arch_list() -> Set[str]: # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: - num = capability[0] + capability[2] - if num == '90': - num = '90a' + if capability.startswith("8.0"): + HAS_SM80 = True + num = "80" + CXX_FLAGS += ["-DHAS_SM80"] + elif capability.startswith("8.6"): + HAS_SM86 = True + num = "86" + CXX_FLAGS += ["-DHAS_SM86"] + elif capability.startswith("8.7"): + HAS_SM87 = True + num = "87" + CXX_FLAGS += ["-DHAS_SM87"] + elif capability.startswith("8.9"): + HAS_SM89 = True + num = "89" + CXX_FLAGS += ["-DHAS_SM89"] + elif capability.startswith("9.0"): HAS_SM90 = True + num = "90a" # need to use sm90a instead of sm90 to use wgmma ptx instruction. CXX_FLAGS += ["-DHAS_SM90"] + elif capability.startswith("10.0"): + HAS_SM100 = True + num = "100" + CXX_FLAGS += ["-DHAS_SM100"] + elif capability.startswith("10.1"): + HAS_SM101 = True + num = "101" + CXX_FLAGS += ["-DHAS_SM101"] + elif capability.startswith("11.0"): + HAS_SM110 = True + num = "110" + CXX_FLAGS += ["-DHAS_SM110"] + elif capability.startswith("12.0"): + HAS_SM120 = True + num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions. + CXX_FLAGS += ["-DHAS_SM120"] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] From dc9ad75882fb19238b31bdec07b5565cb3692b58 Mon Sep 17 00:00:00 2001 From: Johnny Date: Wed, 7 May 2025 10:36:57 +0200 Subject: [PATCH 2/4] Update mma.cuh --- csrc/mma.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/mma.cuh b/csrc/mma.cuh index d9f6551..ef92245 100644 --- a/csrc/mma.cuh +++ b/csrc/mma.cuh @@ -22,6 +22,7 @@ #include #include #include +#include namespace mma{ @@ -49,8 +50,7 @@ namespace mma{ #if defined(__CUDA_ARCH__) #define RUNTIME_ASSERT(x) __brkpt() #else -#include -#define RUNTIME_ASSERT(x) assert(0 && x) +#define RUNTIME_ASSERT(x) printf("%s\n",x);exit(-1) #endif enum class MMAMode { From cf4c2f16655fc11b3a851c91e2ca079803300440 Mon Sep 17 00:00:00 2001 From: Johnny Date: Wed, 7 May 2025 10:37:31 +0200 Subject: [PATCH 3/4] Update numeric_conversion.cuh --- csrc/numeric_conversion.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/numeric_conversion.cuh b/csrc/numeric_conversion.cuh index e09a3d0..6f865a9 100644 --- a/csrc/numeric_conversion.cuh +++ b/csrc/numeric_conversion.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) @@ -32,8 +33,7 @@ #if defined(__CUDA_ARCH__) #define RUNTIME_ASSERT(x) __brkpt() #else -#include -#define RUNTIME_ASSERT(x) assert(0 && x) +#define RUNTIME_ASSERT(x) printf("%s\n",x);exit(-1) #endif __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1) @@ -139,4 +139,4 @@ __device__ __forceinline__ int8_t float_to_int8_rn(float x) uint32_t dst; asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); return reinterpret_cast(dst); -} \ No newline at end of file +} From a4b62d1de354ccd56a2a323055fb554ea1dc2ef0 Mon Sep 17 00:00:00 2001 From: Johnny Date: Wed, 7 May 2025 10:39:12 +0200 Subject: [PATCH 4/4] Update setup.py --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index e5a8f9f..2690706 100644 --- a/setup.py +++ b/setup.py @@ -56,8 +56,8 @@ def get_instantiations(src_dir: str): ] # Supported NVIDIA GPU architectures. -SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "11.0", "12.0"} - +# SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "11.0", "12.0"} +SUPPORTED_ARCHS = {"8.0", "8.6", "8.7", "8.9", "9.0"} # Compiler flags. CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] NVCC_FLAGS = [