
Commit c063911

Add flashinfer_cutedsl grouped gemm

Signed-off-by: Shu Wang <[email protected]>

1 parent: 2935092

File tree: 9 files changed, +1031 −38 lines changed


tests/kernels/moe/test_cutedsl_moe.py

Lines changed: 527 additions & 0 deletions
Large diffs are not rendered by default.
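The test file is collapsed in this view, so only its size is visible. For orientation, the grouped GEMM being exercised multiplies each token's activations by the weight matrix of the expert it was routed to; the following is a hedged, pure-PyTorch sketch of that contract with illustrative names, not code taken from test_cutedsl_moe.py:

import torch


def reference_grouped_gemm(
    hidden: torch.Tensor,      # [num_tokens, hidden_dim] activations
    weights: torch.Tensor,     # [num_experts, hidden_dim, out_dim] expert weights
    expert_ids: torch.Tensor,  # [num_tokens] expert index assigned to each token
) -> torch.Tensor:
    # Loop-based reference: one GEMM per expert over the tokens routed to it.
    out = hidden.new_zeros(hidden.shape[0], weights.shape[-1])
    for e in range(weights.shape[0]):
        mask = expert_ids == e
        if mask.any():
            out[mask] = hidden[mask] @ weights[e]
    return out

A kernel-level test would typically compare the fused grouped-GEMM output against a reference like this with a dtype-appropriate tolerance.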

vllm/envs.py

Lines changed: 9 additions & 2 deletions
@@ -156,7 +156,9 @@
     VLLM_USE_FLASHINFER_MOE_FP16: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
-    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
+    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "cutedsl"] = (
+        "throughput"
+    )
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False

@@ -1051,6 +1053,9 @@ def get_vllm_port() -> int | None:
         "VLLM_MARLIN_USE_ATOMIC_ADD", "0"
     )
     == "1",
+    "VLLM_DEEPEPLL_BF16_DISPATCH": lambda: bool(
+        int(os.getenv("VLLM_DEEPEPLL_BF16_DISPATCH", "0"))
+    ),
     # Whether to use marlin kernel in mxfp4 quantization method
     "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
         os.environ.get("VLLM_MXFP4_USE_MARLIN", None)

@@ -1199,7 +1204,9 @@ def get_vllm_port() -> int | None:
     # - "latency":
     #   Uses TensorRT-LLM kernels optimized for low-latency inference.
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
-        "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
+        "VLLM_FLASHINFER_MOE_BACKEND",
+        "throughput",
+        ["throughput", "latency", "cutedsl"],
     ),
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
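Taken together, the envs.py changes register one new backend choice ("cutedsl") and one new boolean flag (VLLM_DEEPEPLL_BF16_DISPATCH). A minimal usage sketch, assuming vLLM's envs module resolves these lazily on attribute access as it does for the existing flags; the branch bodies below are hypothetical placeholders, not code from this commit:

import os

# Opt into the new FlashInfer CuteDSL MoE backend and BF16 DeepEP-LL dispatch.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "cutedsl"
os.environ["VLLM_DEEPEPLL_BF16_DISPATCH"] = "1"

from vllm import envs

if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl":
    pass  # MoE layers would route grouped GEMMs through the CuteDSL kernels here.
if envs.VLLM_DEEPEPLL_BF16_DISPATCH:
    pass  # DeepEP low-latency prepare/finalize skips activation quantization.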

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 31 additions & 15 deletions
@@ -6,6 +6,8 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,

@@ -24,6 +26,8 @@
 DEEPEP_QUANT_BLOCK_SIZE = 128
 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]

+logger = init_logger(__name__)
+

 def dequant_fp8(
     expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor

@@ -110,21 +114,31 @@ def _do_quant(
         assert isinstance(x, torch.Tensor)

         num_experts, max_tokens, hidden_dim = x.size()
-
-        # TODO (varun): Optimization - Use a batched version of quant
-        x = x.view((-1, hidden_dim))
-        x, x_scales = moe_kernel_quantize_input(
-            x,
-            quant_config.a1_scale,
-            quant_config.quant_dtype,
-            quant_config.per_act_token_quant,
-            quant_config.block_shape,
-        )
-        x = x.view((num_experts, -1, hidden_dim))
-
-        if quant_config.quant_dtype is not None:
-            assert x_scales is not None
-            x_scales = normalize_batched_scales_shape(x_scales, num_experts)
+        if not envs.VLLM_DEEPEPLL_BF16_DISPATCH:
+            # TODO (varun): Optimization - Use a batched version of quant
+            x = x.view((-1, hidden_dim))
+            x, x_scales = moe_kernel_quantize_input(
+                x,
+                quant_config.a1_scale,
+                quant_config.quant_dtype,
+                quant_config.per_act_token_quant,
+                quant_config.block_shape,
+            )
+            x = x.view((num_experts, -1, hidden_dim))
+
+            if quant_config.quant_dtype is not None:
+                assert x_scales is not None
+                x_scales = normalize_batched_scales_shape(x_scales, num_experts)
+        else:
+            # BF16 dispatch path - no quantization
+            # TODO([email protected]): enable nvfp4 dispatch once DEEPEP is ready.
+            logger.info_once("Using BF16 dispatch path for DeepEPLLPrepareAndFinalize")
+            assert x.dtype == torch.bfloat16, (
+                "BF16 dispatch requires input to be in BF16"
+            )
+            x_scales = None
+            x = x.view((num_experts, -1, hidden_dim))
+            # print(f"after deepepll: x.shape = {x.shape}")

         return x, x_scales

@@ -262,6 +276,8 @@ def _finalize(

         # TODO (varun) : Enable zero copy mode
         dbo_maybe_run_recv_hook()
+        # print("xxx"*100, fused_expert_output.shape)
+        # print("ttt"*100, fused_expert_output.dtype)
         _, _, recv_hook = self.buffer.low_latency_combine(
             fused_expert_output,
             topk_ids,
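The net effect of the _do_quant change is a contract on the dispatched tensors: with VLLM_DEEPEPLL_BF16_DISPATCH enabled, activations stay in BF16 and no scales are produced; otherwise the existing quantization path runs unchanged. A standalone sketch of the BF16 branch's shape and dtype invariants, using dummy tensors rather than the actual vLLM call path:

import torch

num_experts, max_tokens, hidden_dim = 4, 8, 256
x = torch.randn(num_experts, max_tokens, hidden_dim, dtype=torch.bfloat16)

# BF16 dispatch path: no quantization, no scales, only a reshape round-trip.
assert x.dtype == torch.bfloat16, "BF16 dispatch requires input to be in BF16"
x_scales = None
x = x.view((-1, hidden_dim)).view((num_experts, -1, hidden_dim))

assert x.shape == (num_experts, max_tokens, hidden_dim)
assert x_scales is None  # downstream experts receive unquantized BF16 activations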
