Commit b90f347

Fix after refactor

Signed-off-by: Shu Wang. <[email protected]>
Parent: 365a8ff

2 files changed: 56 additions, 9 deletions

vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py

Lines changed: 39 additions & 1 deletion
@@ -144,7 +144,6 @@ def apply(
         assert hidden_states.ndim == 3
         assert self.w1_scale.ndim == 3
         assert self.w2_scale.ndim == 3
-
         flashinfer_cutedsl_moe_masked(
             hidden_states=hidden_states,
             input_global_scale=self.a1_gscale,
@@ -306,3 +305,42 @@ def flashinfer_cutedsl_moe_masked(
         alpha_dtype=get_cute_dtype(w2_alpha),
     )  # in logical [m, k, l]
     out = out.permute(2, 0, 1)
+
+
+def flashinfer_cutedsl_moe_fp4(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    quant_config: FusedMoEQuantConfig,
+    inplace: bool = False,
+    activation: str = "silu",
+    global_num_experts: int = -1,
+    expert_map: torch.Tensor | None = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+        create_flashinfer_prepare_finalize,
+    )
+
+    fused_experts = mk.FusedMoEModularKernel(
+        create_flashinfer_prepare_finalize(use_dp=False),  # could be swapped later
+        FlashInferCuteDSLExperts(
+            out_dtype=hidden_states.dtype,
+            quant_config=quant_config,
+        ),
+    )
+
+    return fused_experts(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=inplace,
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
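
The new flashinfer_cutedsl_moe_fp4 entry point mirrors the keyword signature of flashinfer_cutlass_moe_fp4: the caller supplies routed token weights/ids plus the packed FP4 expert weights, and the function composes a FusedMoEModularKernel from the shared FlashInfer prepare/finalize stage and the CuteDSL experts. A minimal sketch of producing topk_weights/topk_ids with a plain softmax-then-top-k router and handing them to the new entry point; all sizes, router_logits, and the w1/w2/quant_config placeholders are illustrative assumptions, not values from this commit:

import torch

# Illustrative sizes only.
num_tokens, hidden_size, num_experts, topk = 16, 1024, 8, 2

hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)

# Standard softmax-then-top-k routing; vLLM computes this upstream of the
# fused-MoE entry points.
router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float32)
routing_probs = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(routing_probs, k=topk, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# The real call needs CUDA, FlashInfer, and FP4-packed weights/scales from
# the ModelOpt checkpoint path; w1, w2, and quant_config are stand-ins here.
# out = flashinfer_cutedsl_moe_fp4(
#     hidden_states=hidden_states.cuda(),
#     w1=w1,
#     w2=w2,
#     topk_weights=topk_weights.cuda(),
#     topk_ids=topk_ids.cuda(),
#     quant_config=quant_config,
# )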

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 17 additions & 8 deletions
@@ -1734,17 +1734,26 @@ def apply(
                 workspace=layer.workspace,
             )

-        elif self.allow_flashinfer and self.flashinfer_moe_backend in (
-            FlashinferMoeBackend.CUTLASS,
-            FlashinferMoeBackend.CUTEDSL,
-        ):
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                flashinfer_cutlass_moe_fp4,
+        elif self.allow_flashinfer:
+            assert self.flashinfer_moe_backend in (
+                FlashinferMoeBackend.CUTLASS,
+                FlashinferMoeBackend.CUTEDSL,
             )
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+                from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
+                    flashinfer_cutlass_moe_fp4,
+                )

-            assert self.moe_quant_config is not None
+                flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
+            else:
+                from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (  # noqa: E501
+                    flashinfer_cutedsl_moe_fp4,
+                )
+
+                flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4

-            return flashinfer_cutlass_moe_fp4(
+            assert self.moe_quant_config is not None
+            return flashinfer_fn_moe_fp4(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,