diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh index 546e1eec64bb..b90c7e29bf2b 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh @@ -202,7 +202,7 @@ struct cutlass_3x_gemm_sm120 { sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue, void>>; }; diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh index 245f5c10fcad..a2850c92c88a 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh @@ -72,7 +72,7 @@ struct cutlass_3x_gemm_sm120_custom { sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule, void>::CollectiveOp; - using GemmKernel = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue, void>>; }; diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 52f266707bb9..821daf68aad6 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -15,11 +15,11 @@ for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) - # only SM89 and SM120 fully support - # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10) + # fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. # SM90 and SM100 can use this PTX, but it’s simulated # with FP16 MMA, so it cannot achieve any acceleration. - if arch in [89, 120]: + if arch == 89 or arch // 10 == 12: SUPPORT_FP8 = True if arch >= 80: SUPPORT_SM80 = True diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 60681ad930ff..61b54856ee07 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -448,8 +448,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, "FP8 only support Ada Lovelace or newer GPUs."); TORCH_CHECK( major_capability * 10 + minor_capability == 89 || - major_capability * 10 + minor_capability == 120, - "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + major_capability == 12, + "Marlin W4A8-FP8 only support SM89 or SM12x device (It is slower than " "Marlin W4A16 on other devices)."); } diff --git a/csrc/quantization/marlin/generate_kernels.py b/csrc/quantization/marlin/generate_kernels.py index 5ecbc6ac9990..d4b1d10f70e3 100644 --- a/csrc/quantization/marlin/generate_kernels.py +++ b/csrc/quantization/marlin/generate_kernels.py @@ -15,11 +15,11 @@ for arch in sys.argv[1].split(","): arch = arch[: arch.index(".") + 2].replace(".", "") arch = int(arch) - # only SM89 and SM120 fully support - # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. + # SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10) + # fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. # SM90 and SM100 can use this PTX, but it’s simulated # with FP16 MMA, so it cannot achieve any acceleration. - if arch in [89, 120]: + if arch == 89 or arch // 10 == 12: SUPPORT_FP8 = True if arch >= 80: SUPPORT_SM80 = True diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 4a3941b3d172..0788f5c786a7 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -865,7 +865,8 @@ def is_invalid( for sub_case in inner_combinations: if ( sub_case[0] == scalar_types.float8_e4m3fn - and current_platform.get_device_capability() not in [89, 120] + and not current_platform.is_device_capability(89) + and not current_platform.is_device_capability_family(120) ): continue args = sub_case + (m, n, k) + case[4:] @@ -1091,32 +1092,32 @@ def test_fused_marlin_moe( per_act_token_quant=True, ) - marlin_output = fused_marlin_moe( - a, - w1_data.qweight, - w2_data.qweight, - None, - None, - w1_data.scales, - w2_data.scales, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=e_map, - global_scale1=w1_data.global_scale, - global_scale2=w2_data.global_scale, - g_idx1=w1_data.g_idx, - g_idx2=w2_data.g_idx, - input_global_scale1=w1_data.a_scales_factor, - input_global_scale2=w2_data.a_scales_factor, - sort_indices1=w1_data.sort_indices, - sort_indices2=w2_data.sort_indices, - w1_zeros=w1_data.zeros, - w2_zeros=w2_data.zeros, - input_dtype=a_dtype, - quant_type_id=b_type.id, - is_k_full=is_k_full, - ) + marlin_output = fused_marlin_moe( + a, + w1_data.qweight, + w2_data.qweight, + None, + None, + w1_data.scales, + w2_data.scales, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=e_map, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + input_global_scale1=w1_data.a_scales_factor, + input_global_scale2=w2_data.a_scales_factor, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, + input_dtype=a_dtype, + quant_type_id=b_type.id, + is_k_full=is_k_full, + ) torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0) @@ -1166,29 +1167,29 @@ def test_fused_marlin_moe_with_bias(m): a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2 ) - marlin_output = fused_marlin_moe( - a, - w1_data.qweight, - w2_data.qweight, - w1_data.marlin_bias, - w2_data.marlin_bias, - w1_data.scales, - w2_data.scales, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=None, - global_scale1=w1_data.global_scale, - global_scale2=w2_data.global_scale, - g_idx1=w1_data.g_idx, - g_idx2=w2_data.g_idx, - sort_indices1=w1_data.sort_indices, - sort_indices2=w2_data.sort_indices, - w1_zeros=w1_data.zeros, - w2_zeros=w2_data.zeros, - quant_type_id=quant_type.id, - is_k_full=is_k_full, - ) + marlin_output = fused_marlin_moe( + a, + w1_data.qweight, + w2_data.qweight, + w1_data.marlin_bias, + w2_data.marlin_bias, + w1_data.scales, + w2_data.scales, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=None, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, + quant_type_id=quant_type.id, + is_k_full=is_k_full, + ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) @@ -1247,30 +1248,30 @@ def test_fused_marlin_moe_non_gated( activation=activation, ) - marlin_output = fused_marlin_moe( - a, - w1_data.qweight, - w2_data.qweight, - None, # bias1 - None, # bias2 - w1_data.scales, - w2_data.scales, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=None, - global_scale1=w1_data.global_scale, - global_scale2=w2_data.global_scale, - g_idx1=w1_data.g_idx, - g_idx2=w2_data.g_idx, - sort_indices1=w1_data.sort_indices, - sort_indices2=w2_data.sort_indices, - w1_zeros=w1_data.zeros, - w2_zeros=w2_data.zeros, - quant_type_id=quant_type.id, - is_k_full=is_k_full, - activation=activation, - ) + marlin_output = fused_marlin_moe( + a, + w1_data.qweight, + w2_data.qweight, + None, # bias1 + None, # bias2 + w1_data.scales, + w2_data.scales, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=None, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, + quant_type_id=quant_type.id, + is_k_full=is_k_full, + activation=activation, + ) torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0) diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index f918212f763c..8b35fab81ef8 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -381,7 +381,8 @@ def is_invalid( for sub_case in inner_combinations: if ( sub_case[0] == scalar_types.float8_e4m3fn - and current_platform.get_device_capability() not in [89, 120] + and not current_platform.is_device_capability(89) + and not current_platform.is_device_capability_family(120) ): continue args = sub_case + (size_m, size_n, size_k) + case[4:] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index d659effd70ff..69bb47f71783 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -479,9 +479,9 @@ def get_marlin_input_dtype(prefix: str | None = None): elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8": if not current_platform.is_device_capability( 89 - ) and not current_platform.is_device_capability(120): + ) and not current_platform.is_device_capability_family(120): raise ValueError( - "Marlin W4A8-FP8 only support SM89 or SM120 device " + "Marlin W4A8-FP8 only support SM89 or SM12x device " "(It is slower than Marlin W4A16 on other devices). " "You can consider using W4A8-INT8 instead" "(set VLLM_MARLIN_INPUT_DTYPE=int8)."