Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/moe/marlin_moe_wna16/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
if arch == 89 or arch // 10 == 12:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
Expand Down
4 changes: 2 additions & 2 deletions csrc/moe/marlin_moe_wna16/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
major_capability == 12,
"Marlin W4A8-FP8 only support SM89 or SM12x device (It is slower than "
"Marlin W4A16 on other devices).");
}

Expand Down
6 changes: 3 additions & 3 deletions csrc/quantization/marlin/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM89 and the SM12x family (SM120 RTX 5090, SM121 DGX Spark GB10)
# fully support mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
if arch == 89 or arch // 10 == 12:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
Expand Down
2 changes: 1 addition & 1 deletion csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ struct cutlass_3x_gemm_sm120 {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct cutlass_3x_gemm_sm120_custom {
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule, void>::CollectiveOp;

using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
using GemmKernel = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};

Expand Down
3 changes: 2 additions & 1 deletion tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,8 @@ def is_invalid(
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
and not current_platform.is_device_capability(89)
and not current_platform.is_device_capability_family(120)
):
continue
args = sub_case + (m, n, k) + case[4:]
Expand Down
3 changes: 2 additions & 1 deletion tests/kernels/quantization/test_marlin_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ def is_invalid(
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
and not current_platform.is_device_capability(89)
and not current_platform.is_device_capability_family(120)
):
continue
args = sub_case + (size_m, size_n, size_k) + case[4:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,9 @@ def get_marlin_input_dtype(prefix: str | None = None):
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
) and not current_platform.is_device_capability_family(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"Marlin W4A8-FP8 only support SM89 or SM12x device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
Expand Down
Loading