diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 52f266707bb9..821daf68aad6 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -15,11 +15,11 @@ for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) - # only SM89 and SM120 fully support - # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10) + # fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. # SM90 and SM100 can use this PTX, but it’s simulated # with FP16 MMA, so it cannot achieve any acceleration. - if arch in [89, 120]: + if arch == 89 or arch // 10 == 12: SUPPORT_FP8 = True if arch >= 80: SUPPORT_SM80 = True diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index e3f3b4175b92..e2055e0db0bb 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -448,8 +448,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, "FP8 only support Ada Lovelace or newer GPUs."); TORCH_CHECK( major_capability * 10 + minor_capability == 89 || - major_capability * 10 + minor_capability == 120, - "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + major_capability == 12, + "Marlin W4A8-FP8 only support SM89 or SM12x device (It is slower than " "Marlin W4A16 on other devices)."); } diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 0c04f010888d..8b3aaf248d2a 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -37,6 +37,56 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16; namespace vllm { +// Software E2M1 conversion for architectures without hardware +// cvt.rn.satfinite.e2m1x2.f32 (SM12x lacks this SM100-only instruction). 
+// Uses round-to-nearest-even (IEEE 754) to match hardware behavior: +// at midpoints, ties break to the value with an even integer code. +// E2M1 representable values and codes: +// 0.0(0) 0.5(1) 1.0(2) 1.5(3) 2.0(4) 3.0(5) 4.0(6) 6.0(7) +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 + +__device__ __forceinline__ uint8_t sw_float_to_e2m1(float v) { + uint8_t sign = (__float_as_uint(v) >> 31) & 1; + float av = fabsf(v); + uint8_t e2m1; + // Midpoint tie-breaking: <= rounds to lower (even) code, < rounds to upper. + if (av <= 0.25f) + e2m1 = 0; // 0.0; midpoint 0.25 → code 0 (even) + else if (av < 0.75f) + e2m1 = 1; // 0.5; midpoint 0.75 → code 2 (even, next branch) + else if (av <= 1.25f) + e2m1 = 2; // 1.0; midpoint 1.25 → code 2 (even) + else if (av < 1.75f) + e2m1 = 3; // 1.5; midpoint 1.75 → code 4 (even, next branch) + else if (av <= 2.5f) + e2m1 = 4; // 2.0; midpoint 2.5 → code 4 (even) + else if (av < 3.5f) + e2m1 = 5; // 3.0; midpoint 3.5 → code 6 (even, next branch) + else if (av <= 5.0f) + e2m1 = 6; // 4.0; midpoint 5.0 → code 6 (even) + else + e2m1 = 7; // 6.0 (satfinite) + return (sign << 3) | e2m1; +} + +// Pack two E2M1 values into one byte (matches cvt.rn.satfinite.e2m1x2.f32 +// layout: hi in upper nibble, lo in lower nibble). +__device__ __forceinline__ uint8_t sw_e2m1x2_from_f32(float hi, float lo) { + return (sw_float_to_e2m1(hi) << 4) | sw_float_to_e2m1(lo); +} + +// Pack 8 float values (as 4 float2) into a uint32_t of E2M1 values. 
+__device__ __forceinline__ uint32_t sw_fp32_vec8_to_e2m1(const float2* array) {
+  uint8_t b0 = sw_e2m1x2_from_f32(array[0].y, array[0].x);
+  uint8_t b1 = sw_e2m1x2_from_f32(array[1].y, array[1].x);
+  uint8_t b2 = sw_e2m1x2_from_f32(array[2].y, array[2].x);
+  uint8_t b3 = sw_e2m1x2_from_f32(array[3].y, array[3].x);
+  return (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) |
+         ((uint32_t)b3 << 24);
+}
+
+#endif  // SM12x software E2M1
+
 template <typename Int>
 __host__ __device__ inline Int round_up(Int x, Int y) {
   static_assert(std::is_integral_v<Int>,
@@ -70,6 +120,9 @@ inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
 inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
   uint32_t val;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300
+  val = sw_fp32_vec8_to_e2m1(reinterpret_cast<const float2*>(array));
+#else
   asm volatile(
       "{\n"
       ".reg .b8 byte0;\n"
@@ -85,12 +138,16 @@ inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
   uint32_t val;
       : "=r"(val)
       : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
+#endif
   return val;
 }
__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { uint32_t val; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 + val = sw_fp32_vec8_to_e2m1(array); +#else asm volatile( "{\n" ".reg .b8 byte0;\n" @@ -106,6 +163,7 @@ __device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { : "=r"(val) : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); +#endif return val; } @@ -117,6 +175,10 @@ using fp4_packed_t = std::conditional_t; __device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { u32x2 out; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 + out.lo = sw_fp32_vec8_to_e2m1(array); + out.hi = sw_fp32_vec8_to_e2m1(array + 4); +#else asm volatile( "{\n" ".reg .b8 b0;\n" @@ -143,6 +205,7 @@ __device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y), "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y)); +#endif return out; } diff --git a/csrc/quantization/marlin/generate_kernels.py b/csrc/quantization/marlin/generate_kernels.py index 5ecbc6ac9990..d4b1d10f70e3 100644 --- a/csrc/quantization/marlin/generate_kernels.py +++ b/csrc/quantization/marlin/generate_kernels.py @@ -15,11 +15,11 @@ for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) - # only SM89 and SM120 fully support - # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10) + # fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. # SM90 and SM100 can use this PTX, but it’s simulated # with FP16 MMA, so it cannot achieve any acceleration. 
-    if arch in [89, 120]:
+    if arch == 89 or arch // 10 == 12:
         SUPPORT_FP8 = True
     if arch >= 80:
         SUPPORT_SM80 = True
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
index 546e1eec64bb..b90c7e29bf2b 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
+++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
@@ -202,7 +202,7 @@ struct cutlass_3x_gemm_sm120 {
           sizeof(typename CollectiveEpilogue::SharedStorage))>,
       KernelSchedule>::CollectiveOp;
-  using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
+  using GemmKernel = enable_sm120_or_later<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
 };
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
index 37846a87bbfb..8f9817d9052b 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
+++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
@@ -70,7 +70,7 @@ struct cutlass_3x_gemm_sm120_custom {
           sizeof(typename CollectiveEpilogue::SharedStorage))>,
       KernelSchedule, void>::CollectiveOp;
-  using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
+  using GemmKernel = enable_sm120_or_later<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
 };
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 28be9f23d661..c88ce8799015 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -865,7 +865,8 @@ def is_invalid(
     for sub_case in inner_combinations:
         if (
             sub_case[0] == scalar_types.float8_e4m3fn
-            and current_platform.get_device_capability() not in [89, 120]
+            and not current_platform.is_device_capability(89)
+            and not current_platform.is_device_capability_family(120)
         ):
             continue
         args = sub_case + (m, n, k) + case[4:]
diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py
index f918212f763c..8b35fab81ef8 100644
--- a/tests/kernels/quantization/test_marlin_gemm.py
+++ b/tests/kernels/quantization/test_marlin_gemm.py
@@ -381,7 +381,8 @@ def 
is_invalid( for sub_case in inner_combinations: if ( sub_case[0] == scalar_types.float8_e4m3fn - and current_platform.get_device_capability() not in [89, 120] + and not current_platform.is_device_capability(89) + and not current_platform.is_device_capability_family(120) ): continue args = sub_case + (size_m, size_n, size_k) + case[4:] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index d659effd70ff..69bb47f71783 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -479,9 +479,9 @@ def get_marlin_input_dtype(prefix: str | None = None): elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8": if not current_platform.is_device_capability( 89 - ) and not current_platform.is_device_capability(120): + ) and not current_platform.is_device_capability_family(120): raise ValueError( - "Marlin W4A8-FP8 only support SM89 or SM120 device " + "Marlin W4A8-FP8 only support SM89 or SM12x device " "(It is slower than Marlin W4A16 on other devices). " "You can consider using W4A8-INT8 instead" "(set VLLM_MARLIN_INPUT_DTYPE=int8)."