Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vllm/model_executor/layers/fused_moe/cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,12 @@ def expects_unquantized_inputs(self) -> bool:

@staticmethod
def _supports_current_device() -> bool:
return current_platform.has_device_capability((10, 0))
p = current_platform
return p.is_cuda() and (
p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)

@staticmethod
def _supports_no_act_and_mul() -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def activation_format() -> mk.FusedMoEActivationFormat:

@staticmethod
def _supports_current_device() -> bool:
return current_platform.is_device_capability_family(100)
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)

@staticmethod
def _supports_no_act_and_mul() -> bool:
Expand Down
27 changes: 14 additions & 13 deletions vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,14 @@ def expects_unquantized_inputs(self) -> bool:

@staticmethod
def _supports_current_device() -> bool:
p = current_platform
return (
current_platform.is_cuda()
p.is_cuda()
and (
current_platform.is_device_capability((9, 0))
or current_platform.is_device_capability_family(100)
p.is_device_capability(90)
or p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability_family(120)
)
and has_flashinfer_cutlass_fused_moe()
)
Expand All @@ -102,29 +105,27 @@ def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# The following are supported by FlashInferExperts:
# * unquantized
# * fp8 static per-tensor on 9.0+
# * fp8 block on 9.0
# * nvfp4 on 10.0+

p = current_platform
scheme = (weight_key, activation_key)
# The following are supported by FlashInferExperts:
return (
# unquantized and fp8 static per-tensor on 9.0+
(
scheme
in [
(None, None),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
and p.has_device_capability(90)
)
# fp8 block-scale on 9.0
or (
(scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym))
and (p.is_device_capability((9, 0)))
scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym)
and p.is_device_capability(90)
)
# nvfp4 on 10.0+
or (
(scheme == (kNvfp4Static, kNvfp4Dynamic))
and (p.is_device_capability_family(100))
scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100)
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
# Add check flashinfer trtllm is available
return p.is_cuda() and p.is_device_capability_family(100)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import torch

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
Expand All @@ -24,10 +23,6 @@
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)

if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
Expand All @@ -38,8 +33,6 @@


__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
]

Expand Down Expand Up @@ -124,26 +117,6 @@ def _make_reason(reason: str) -> str:
return True, None


def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)


def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
and current_platform.is_cuda()
and current_platform.is_device_capability_family(100)
)


def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down

This file was deleted.