Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions tests/kernels/moe/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def quant_fp8_per_tensor_batches(a):

for i in range(num_batches):
a_fp8, a_global_sf = input_to_float8(a[i])
a_global_sf = 1.0 / a_global_sf
if a_global_sf.numel() == 1:
a_global_sf = a_global_sf.view(1, 1)
a_quant.append(a_fp8)
a_scales.append(a_global_sf)

Expand All @@ -81,6 +82,20 @@ def quant_fp8_per_tensor_batches(a):
return result_a_quant, result_a_scales


def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
match_ratio = close.float().mean()
assert match_ratio >= percent, (
f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}"
)

mismatch_percent = 1.0 - match_ratio.item()
assert mismatch_percent <= 1 - percent, (
f"Mismatch percentage {mismatch_percent:.4f} is above the threshold "
f"{1 - percent:.4f}"
)


@dataclass
class TestData:
hidden_states: torch.Tensor
Expand All @@ -104,14 +119,16 @@ def make_moe_tensors_8bit(
is_gated = activation.is_gated

hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn(
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
w13 = (
torch.randn(
(e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
)
/ 10
)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10

# Scale to fp8
_, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
Expand All @@ -124,14 +141,16 @@ def make_moe_tensors_8bit(
layer.w2_input_scale = a2_scale
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
layer.activation = activation
# Setup dummy config.
layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()

# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if is_gated:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if is_trtllm:
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
layer.w13_weight, layer.w2_weight
layer.w13_weight, layer.w2_weight, is_gated
)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
Expand Down Expand Up @@ -162,20 +181,24 @@ def make_moe_tensors_8bit(
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
activation: MoEActivation,
monkeypatch,
):
if not current_platform.has_device_capability(100):
pytest.skip("Test is only supported for sm >= 100")
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=True, activation=activation
)

score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
Expand All @@ -200,7 +223,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation=MoEActivation.SILU,
activation=activation,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
Expand All @@ -219,7 +242,13 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
apply_router_weight_on_input=True,
)

torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
check_accuracy(
ref_output=output,
actual_output=flashinfer_output,
atol=0.1,
rtol=0.85,
percent=0.925,
)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
Expand Down Expand Up @@ -320,8 +349,13 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
expert_map=None,
apply_router_weight_on_input=True,
)
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2

check_accuracy(
ref_output=output,
actual_output=flashinfer_cutlass_output,
atol=0.1,
rtol=0.85,
percent=0.925,
)


Expand Down
14 changes: 8 additions & 6 deletions vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def _supports_current_device() -> bool:


def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nanotron-Mini)."""
return False
"""Supports non-gated MoE."""
return True


def _supports_quant_scheme(
Expand All @@ -52,8 +52,7 @@ def _supports_quant_scheme(


def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only."""
return activation == MoEActivation.SILU
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]


def _supports_routing_method(
Expand All @@ -74,6 +73,7 @@ def _supports_routing_method(
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
Expand Down Expand Up @@ -291,6 +291,7 @@ def fi_trtllm_fp8_per_tensor_moe(
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0
Expand Down Expand Up @@ -326,9 +327,9 @@ def fi_trtllm_fp8_per_tensor_moe(
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
routing_method_type=routing_method_type,
# TODO: Required for flashinfer==0.6.3, remove with update
# TODO: enum type Required for flashinfer==0.6.3, remove with update
# https://github.com/flashinfer-ai/flashinfer/pull/2508
activation_type=ActivationType.Swiglu,
activation_type=ActivationType(activation_type),
)


Expand All @@ -351,6 +352,7 @@ def fi_trtllm_fp8_per_tensor_moe_fake(
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
activation_type: int,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,10 +937,11 @@ def apply_monolithic(
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}"
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert layer.activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {layer.activation} found instead."
)
assert not layer.renormalize
Comment on lines 939 to -943
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: we need to update the compressed tensors side too, can do in followup PR

return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
align_fp4_moe_weights_for_fi,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale,
)
Expand Down Expand Up @@ -50,8 +54,8 @@ def _supports_current_device() -> bool:


def _supports_no_act_and_mul() -> bool:
"""Does not support non-gated MoE (i.e. Nemotron-Nano)."""
return False
"""Supports non-gated MoE."""
return True


def _supports_quant_scheme(
Expand All @@ -66,8 +70,7 @@ def _supports_quant_scheme(


def _supports_activation(activation: MoEActivation) -> bool:
"""Supports silu activation only."""
return activation in [MoEActivation.SILU]
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]


def _supports_routing_method(
Expand Down Expand Up @@ -150,6 +153,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
hidden_size,
intermediate_size,
num_experts,
is_gated_activation: bool,
):
from flashinfer import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import (
Expand All @@ -160,15 +164,18 @@ def prepare_static_weights_for_trtllm_fp4_moe(
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
"""Prepare quantized weights for kernel (done offline with weights)."""
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
gemm1_intermediate_size = (
2 * intermediate_size if is_gated_activation else intermediate_size
)

# Convert quantized weights to proper formats
gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2
num_experts, gemm1_intermediate_size, hidden_size // 2
) # packed fp4
gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
torch.float8_e4m3fn
).reshape(
num_experts, 2 * intermediate_size, hidden_size // 16
num_experts, gemm1_intermediate_size, hidden_size // 16
) # fp8 scaling factors

gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
Expand All @@ -191,6 +198,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
_cache_permute_indices,
gemm1_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
is_gated_act_gemm=is_gated_activation,
)
gemm1_weights_fp4_shuffled.append(
gemm1_weights_fp4[i]
Expand All @@ -203,6 +211,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
gemm1_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
is_gated_act_gemm=is_gated_activation,
)
gemm1_scales_fp4_shuffled.append(
nvfp4_block_scale_interleave(
Expand Down Expand Up @@ -246,7 +255,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
gemm1_scales_fp4_shuffled = (
torch.stack(gemm1_scales_fp4_shuffled)
.view(torch.float8_e4m3fn)
.reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
.reshape(num_experts, gemm1_intermediate_size, hidden_size // 16)
)

gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
Expand Down Expand Up @@ -297,10 +306,10 @@ def flashinfer_trtllm_fp4_moe(

from vllm.model_executor.models.llama4 import Llama4MoE

# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
f"{activation} found instead."
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {activation} found instead."
)

# Quantize input to FP4
Expand All @@ -325,6 +334,9 @@ def flashinfer_trtllm_fp4_moe(
else router_logits
)

# Determine activation type
activation_type = activation_to_flashinfer_int(layer.activation)

# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
Expand Down Expand Up @@ -355,6 +367,7 @@ def flashinfer_trtllm_fp4_moe(
routed_scaling_factor=None,
routing_method_type=routing_method_type,
do_finalize=True,
activation_type=activation_type,
)[0]

return out
Expand Down Expand Up @@ -479,10 +492,16 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
]

# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
if is_act_and_mul and backend in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
]:
is_gated = layer.activation.is_gated
if (
is_gated
and is_act_and_mul
and backend
in [
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM,
]
):
w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale)

# For some FI kernels, the input scales are shared by all experts.
Expand All @@ -495,19 +514,32 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(

# Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels.
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
# Align weights for FI NVFP4 MoE kernels.
min_alignment = 16 if is_gated else 128
w13, w13_scale, w2, w2_scale, padded_intermediate = (
align_fp4_moe_weights_for_fi(
w13, w13_scale, w2, w2_scale, is_act_and_mul, min_alignment
)
)
layer.intermediate_size_per_partition = padded_intermediate

w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13,
w2,
w13_scale,
w2_scale,
w2.size(-2), # hidden_size
w13.size(-2) // 2, # intermediate_size
w13.size(0), # num_experts
hidden_size=w2.size(-2),
intermediate_size=w13.size(-2) // 2 if is_gated else w13.size(-2),
num_experts=w13.size(0),
is_gated_activation=is_gated,
)

# We do not need to make this a parameter, because
# it is not used during the weight (re)-loading process.
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
if is_gated:
layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale
else:
layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale
layer.a1_gscale = 1.0 / a13_scale
layer.g1_alphas = a13_scale * w13_scale_2
layer.g2_alphas = a2_scale * w2_scale_2
Expand Down
Loading