
Commit 8a224da

Make fused version work with cuda graph
Signed-off-by: Shu Wang <[email protected]>
1 parent c063911 commit 8a224da

File tree

vllm/envs.py
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
vllm/utils/flashinfer.py

4 files changed: +39, -102 lines changed

vllm/envs.py

Lines changed: 0 additions & 3 deletions
@@ -1053,9 +1053,6 @@ def get_vllm_port() -> int | None:
         "VLLM_MARLIN_USE_ATOMIC_ADD", "0"
     )
     == "1",
-    "VLLM_DEEPEPLL_BF16_DISPATCH": lambda: bool(
-        int(os.getenv("VLLM_DEEPEPLL_BF16_DISPATCH", "0"))
-    ),
     # Whether to use marlin kernel in mxfp4 quantization method
     "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
         os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
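
For reference, vllm/envs.py registers each variable as a lazy getter in a dict of lambdas, so deleting the entry above removes the flag entirely. A minimal standalone sketch of that pattern (illustrative only, not the full vllm module, which registers many more variables and resolves them through module-level attribute access):

import os

# Sketch of the dict-of-lambdas pattern used by vllm/envs.py; the removed flag
# is reused here purely as an example.
environment_variables = {
    # Parsed as a boolean from "0"/"1", mirroring the entry deleted above.
    "VLLM_DEEPEPLL_BF16_DISPATCH": lambda: bool(
        int(os.getenv("VLLM_DEEPEPLL_BF16_DISPATCH", "0"))
    ),
}

def read_env(name: str):
    # Evaluate the lazy getter only when the value is actually needed.
    return environment_variables[name]()

print(read_env("VLLM_DEEPEPLL_BF16_DISPATCH"))  # False unless the variable is set to "1"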

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 23 additions & 26 deletions
@@ -114,31 +114,30 @@ def _do_quant(
         assert isinstance(x, torch.Tensor)

         num_experts, max_tokens, hidden_dim = x.size()
-        if not envs.VLLM_DEEPEPLL_BF16_DISPATCH:
-            # TODO (varun): Optimization - Use a batched version of quant
-            x = x.view((-1, hidden_dim))
-            x, x_scales = moe_kernel_quantize_input(
-                x,
-                quant_config.a1_scale,
-                quant_config.quant_dtype,
-                quant_config.per_act_token_quant,
-                quant_config.block_shape,
-            )
-            x = x.view((num_experts, -1, hidden_dim))
-
-            if quant_config.quant_dtype is not None:
-                assert x_scales is not None
-                x_scales = normalize_batched_scales_shape(x_scales, num_experts)
-        else:
-            # BF16 dispatch path - no quantization
-            # TODO([email protected]): enable nvfp4 dispatch once DEEPEP is ready.
-            logger.info_once("Using BF16 dispatch path for DeepEPLLPrepareAndFinalize")
-            assert x.dtype == torch.bfloat16, (
-                "BF16 dispatch requires input to be in BF16"
+
+        # TODO (varun): Optimization - Use a batched version of quant
+        x = x.view((-1, hidden_dim))
+        q_dtype = quant_config.quant_dtype
+
+        if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl":
+            logger.info_once(
+                "Skip quantization when using FlashInfer CUTEDSL for "
+                "ModelOptNvFp4FusedMoE."
             )
-            x_scales = None
-            x = x.view((num_experts, -1, hidden_dim))
-        # print(f"after deepepll: x.shape = {x.shape}")
+            q_dtype = None
+
+        x, x_scales = moe_kernel_quantize_input(
+            x,
+            quant_config.a1_scale,
+            q_dtype,
+            quant_config.per_act_token_quant,
+            quant_config.block_shape,
+        )
+        x = x.view((num_experts, -1, hidden_dim))
+
+        if q_dtype is not None:
+            assert x_scales is not None
+            x_scales = normalize_batched_scales_shape(x_scales, num_experts)

         return x, x_scales

@@ -276,8 +275,6 @@ def _finalize(

         # TODO (varun) : Enable zero copy mode
         dbo_maybe_run_recv_hook()
-        # print("xxx"*100, fused_expert_output.shape)
-        # print("ttt"*100, fused_expert_output.dtype)
         _, _, recv_hook = self.buffer.low_latency_combine(
             fused_expert_output,
             topk_ids,
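
After this hunk, _do_quant has a single code path: when the FlashInfer CUTEDSL backend is selected, quantization is skipped by passing a None quant dtype instead of branching into a separate BF16 dispatch block, which lines up with the commit's stated goal of working with CUDA graphs. A minimal sketch of the resulting control flow, with a placeholder standing in for vllm's moe_kernel_quantize_input and scale normalization:

import torch

# Placeholder standing in for vllm's moe_kernel_quantize_input.
def fake_quantize(x, dtype):
    return x.to(dtype), torch.ones(1)

def do_quant_sketch(x: torch.Tensor, quant_dtype, use_flashinfer_cutedsl: bool):
    num_experts, max_tokens, hidden_dim = x.size()
    x = x.view((-1, hidden_dim))

    # With the CUTEDSL backend, dispatch stays unquantized here; quantization
    # happens later inside the fused FlashInfer kernel.
    q_dtype = None if use_flashinfer_cutedsl else quant_dtype

    x_scales = None
    if q_dtype is not None:
        x, x_scales = fake_quantize(x, q_dtype)
        x_scales = x_scales.expand(num_experts)  # stand-in for scale normalization

    return x.view((num_experts, -1, hidden_dim)), x_scales

out, scales = do_quant_sketch(
    torch.randn(4, 8, 16, dtype=torch.bfloat16), torch.float8_e4m3fn, True
)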

vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py

Lines changed: 7 additions & 65 deletions
@@ -13,8 +13,8 @@
 from vllm.utils.flashinfer import (
     flashinfer_cutedsl_grouped_gemm_nt_masked,
     has_flashinfer_cutedsl_grouped_gemm_nt_masked,
-    nvfp4_batched_quantize,
-    silu_and_mul,
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )

 logger = init_logger(__name__)

@@ -110,18 +110,9 @@ def workspace_shapes(
         - Note: in order for activation chunking to work, the first dimension
         of each tuple must be the number of tokens.
         """
-        # assert a.dim() == 2
-        # assert aq.dim() == 3
-        # output_shape = aq.shape
-        # workspace_dtype = a.dtype
-        # E = aq.size(0)
-        # workspace2 = (E, M, N)
-        # workspace1 = output_shape
         output_shape = (local_num_experts, M, K)
         workspace2 = (local_num_experts, M, N)
         workspace1 = output_shape
-        # The workspace is determined by `aq`, since it comes after any
-        # potential communication op and is involved in the expert computation.
         return (workspace1, workspace2, output_shape)

     def apply(

@@ -182,54 +173,6 @@ def get_cute_dtype(input: torch.Tensor) -> str:
     raise ValueError(f"Unsupported cute dtype {input.dtype}")


-def scaled_fp4_grouped_quant(
-    input_tensor: torch.Tensor,
-    input_global_scale: torch.Tensor,
-    mask: torch.Tensor,
-):
-    """
-    Wrapper around nvfp4_batched_quantize
-
-    Args:
-        input_tensor (Tensor):
-        - Shape (l, m, k)
-        input_global_scale (Tensor): Shape (l,)
-        mask (Tensor): Mask tensor, broadcastable
-
-    Returns:
-        output (Tensor): Quantized tensor, logical shape (m, k//2, l)
-        output_scales (Tensor): Blockscale tensor, logical shape
-            (32, 4, rm, 4, rk, l)
-    """
-    num_experts, m, k = input_tensor.shape
-
-    sf_vec_size = 16
-    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
-
-    scale_k = k // sf_vec_size
-    padded_k = (scale_k + (4 - 1)) // 4 * 4
-    padded_m = (m + (128 - 1)) // 128 * 128
-
-    aq, aq_sf = nvfp4_batched_quantize(
-        input_tensor,
-        input_global_scale,
-        mask=mask,
-    )
-
-    # --- re-layout quantized tensor ---
-    # physical (l, m, k//2) -> logical (m, k//2, l)
-    output = aq.permute(1, 2, 0)
-
-    # --- re-layout blockscales ---
-    # physical (l, rm, rk, 32, 4, 4) -> logical (32, 4, rm, 4, rk, l)
-    output_scales = aq_sf.view(torch.float8_e4m3fn).view(
-        num_experts, padded_m // 128, padded_k // 4, 32, 4, 4
-    )
-    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
-
-    return output, output_scales
-
-
 def flashinfer_cutedsl_moe_masked(
     hidden_states: torch.Tensor,
     input_global_scale: torch.Tensor,

@@ -313,10 +256,10 @@ def flashinfer_cutedsl_moe_masked(
         f"w2_alpha must be (l,), got {w2_alpha.shape}"
     )

-    aq, aq_sf = scaled_fp4_grouped_quant(
+    aq, aq_sf = scaled_fp4_grouped_quantize(
         hidden_states,
-        input_global_scale,
         masked_m,
+        input_global_scale,
     )

     workspace = workspace.permute(1, 2, 0)  # requirement of kernel

@@ -343,11 +286,10 @@ def flashinfer_cutedsl_moe_masked(
     )  # in logical [m, n, l]

     # SILU and quantization
-
-    diq, diq_sf = scaled_fp4_grouped_quant(
-        silu_and_mul(workspace.permute(2, 0, 1)),
-        a2_global_scale,
+    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
+        workspace.permute(2, 0, 1),
         masked_m,
+        a2_global_scale,
     )

     # Gemm2
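
The deleted scaled_fp4_grouped_quant helper appears to have existed mainly to re-lay-out nvfp4_batched_quantize outputs into the logical shapes the CuTe DSL grouped GEMM expects; the commit replaces it with FlashInfer's scaled_fp4_grouped_quantize and silu_and_mul_scaled_nvfp4_experts_quantize, with the call sites now passing masked_m before the global scale. A pure-torch illustration of the re-layout the old wrapper performed, using dummy tensors in place of real quantizer outputs:

import torch

# Shapes follow the deleted docstring: l experts, m tokens per expert, k hidden size.
l, m, k = 4, 128, 256
sf_vec_size = 16
scale_k = k // sf_vec_size
padded_k = (scale_k + 3) // 4 * 4       # round rk up to a multiple of 4
padded_m = (m + 127) // 128 * 128       # round rm up to a multiple of 128

# Dummy stand-ins for the quantizer outputs (packed fp4 payload + blockscales).
aq = torch.zeros(l, m, k // 2, dtype=torch.uint8)
aq_sf = torch.zeros(l, padded_m // 128, padded_k // 4, 32, 4, 4, dtype=torch.uint8)

# physical (l, m, k//2)          -> logical (m, k//2, l)
output = aq.permute(1, 2, 0)
# physical (l, rm, rk, 32, 4, 4) -> logical (32, 4, rm, 4, rk, l)
output_scales = aq_sf.permute(3, 4, 1, 5, 2, 0)

print(output.shape, output_scales.shape)  # (128, 128, 4) and (32, 4, 1, 4, 4, 4)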

vllm/utils/flashinfer.py

Lines changed: 9 additions & 8 deletions
@@ -101,10 +101,12 @@ def wrapper(*args, **kwargs):
 )
 flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
 nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize")
-silu_and_mul_nvfp4_batched_quantize = _lazy_import_wrapper(
-    "flashinfer", "silu_and_mul_nvfp4_batched_quantize"
+silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
+    "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
+)
+scaled_fp4_grouped_quantize = _lazy_import_wrapper(
+    "flashinfer", "scaled_fp4_grouped_quantize"
 )
-silu_and_mul = _lazy_import_wrapper("flashinfer", "silu_and_mul")
 nvfp4_block_scale_interleave = _lazy_import_wrapper(
     "flashinfer", "nvfp4_block_scale_interleave"
 )

@@ -194,8 +196,8 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
     # Check if all required functions are available
     required_functions = [
         ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
-        ("flashinfer", "silu_and_mul"),
-        ("flashinfer", "nvfp4_batched_quantize"),
+        ("flashinfer", "scaled_fp4_grouped_quantize"),
+        ("flashinfer", "silu_and_scaled_nvfp4_experts_quantize"),
     ]

     for module_name, attr_name in required_functions:

@@ -482,9 +484,8 @@ def flashinfer_disable_q_quantization() -> bool:
     "flashinfer_cutlass_fused_moe",
     "flashinfer_cutedsl_grouped_gemm_nt_masked",
     "flashinfer_fp4_quantize",
-    "silu_and_mul_nvfp4_batched_quantize",
-    "silu_and_mul",
-    "nvfp4_batched_quantize",
+    "silu_and_mul_scaled_nvfp4_experts_quantize",
+    "scaled_fp4_grouped_quantize",
     "nvfp4_block_scale_interleave",
     "trtllm_fp4_block_scale_moe",
     "autotune",

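All of the names touched here are bound through vllm's _lazy_import_wrapper, so a rename only surfaces when the wrapped function is first called or when has_flashinfer_cutedsl_grouped_gemm_nt_masked probes for it. A minimal sketch of the lazy-binding idea (illustrative only; vllm's real helper differs, for example in how it handles a missing flashinfer install):

import importlib
from typing import Any, Callable

def lazy_import_wrapper(module_name: str, attr_name: str) -> Callable[..., Any]:
    # Defer the import until the first call so merely importing the utility
    # module does not require flashinfer to be installed.
    def wrapper(*args, **kwargs):
        module = importlib.import_module(module_name)
        fn = getattr(module, attr_name)  # attribute name must match flashinfer's API
        return fn(*args, **kwargs)
    return wrapper

# Bindings mirroring the ones added in this commit.
scaled_fp4_grouped_quantize = lazy_import_wrapper("flashinfer", "scaled_fp4_grouped_quantize")
silu_and_mul_scaled_nvfp4_experts_quantize = lazy_import_wrapper(
    "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
)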