Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/compile/fusions_e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def run(
f"attention backend '{attn_backend.backend.name}'"
)

# TODO: remove this after finishing migration from envs to model kwargs
if model_name == "openai/gpt-oss-20b":
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")

# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
Expand Down
9 changes: 9 additions & 0 deletions tests/compile/fusions_e2e/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,12 @@
# async_tp=n_layers * 2,
),
)

def _gpt_oss_20b_matches(n_layers: int) -> Matches:
    """Expected fusion-match counts for gpt-oss-20b given a layer count.

    Each layer contributes two ar_rms_fusion / sequence_parallel matches
    plus one extra overall; async_tp matches two per layer.
    """
    per_layer = 2 * n_layers
    return Matches(
        ar_rms_fusion=per_layer + 1,
        sequence_parallel=per_layer + 1,
        async_tp=per_layer,
    )


gpt_oss_20b = ModelFusionInfo(
    model_name="openai/gpt-oss-20b",
    matches=_gpt_oss_20b_matches,
)
3 changes: 2 additions & 1 deletion tests/compile/fusions_e2e/test_tp2_ar_rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
FLASHINFER_MLA_ATTN,
TRITON_ATTN,
deepseek_v3_fp8,
gpt_oss_20b,
llama3_8b,
llama3_8b_fp4,
llama3_8b_fp8,
Expand Down Expand Up @@ -158,7 +159,7 @@ def test_tp2_ar_rms_fp4_fusions(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b],
[llama3_8b, qwen3_a3b, gpt_oss_20b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize("n_layers", [4])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def topk_indices_dtype(self) -> torch.dtype | None:
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
return None

@property
def skip_forward_padding(self) -> bool:
    """Whether the hidden-dim padding in the MoE forward can be skipped.

    When True, the caller skips padding hidden states to the configured
    hidden dim before invoking this quant method (the method is expected
    to handle any required padding itself). Base implementation returns
    False, i.e. padding is performed as usual; backends that pad
    internally override this.
    """
    return False

@property
def supports_eplb(self) -> bool:
    """Whether this quant method supports EPLB.

    NOTE(review): EPLB presumably stands for expert-parallel load
    balancing — confirm against the callers. Base implementation opts
    out; capable backends override this to return True.
    """
    return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,10 @@ def forward(

# This is the dimension after transform (for routed expert output slicing)
transformed_hidden_dim = hidden_states.shape[-1]
if self.moe_config.hidden_dim != transformed_hidden_dim:
if (
not self.quant_method.skip_forward_padding
and self.moe_config.hidden_dim != transformed_hidden_dim
):
hidden_states = F.pad(
hidden_states,
(0, self.moe_config.hidden_dim - transformed_hidden_dim),
Expand Down
17 changes: 16 additions & 1 deletion vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ def __init__(self, moe: FusedMoEConfig):
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_kernel: mk.FusedMoEKernel | None = None

@property
def skip_forward_padding(self) -> bool:
    """Whether the pre-apply hidden-dim padding in forward can be skipped.

    The SM100 FlashInfer mxfp4/mxfp8 TRTLLM backend pads the activations
    itself during mxfp8 quantization, so the generic padding performed in
    the forward pass before applying the MoE method is redundant for it.
    """
    uses_trtllm_mxfp8 = (
        self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
    )
    return uses_trtllm_mxfp8

def create_weights(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -1130,9 +1136,17 @@ def apply_monolithic(
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
from flashinfer import mxfp8_quantize

x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
# x_quant is padded in hidden dimension with alignment=256
x_quant, x_scale = mxfp8_quantize(
x,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)

# output with original unpadded hidden size
output = torch.empty_like(x)

trtllm_gen_output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
Expand Down Expand Up @@ -1161,6 +1175,7 @@ def apply_monolithic(
routing_method_type=1 if layer.renormalize else 0,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
output=output,
)[0]
return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.CK:
Expand Down
Loading