Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fused_moe/trtllm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def __init__(
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
max_capture_size,
tune_max_num_tokens,
):
super().__init__(quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.max_capture_size = max_capture_size
self.tune_max_num_tokens = tune_max_num_tokens

@property
def activation_formats(
Expand Down Expand Up @@ -127,7 +127,7 @@ def apply(
"routing_method_type": 1,
"do_finalize": True,
"output": output,
"tune_max_num_tokens": max(self.max_capture_size, 1),
"tune_max_num_tokens": self.tune_max_num_tokens,
}

from flashinfer import trtllm_fp4_block_scale_routed_moe
Expand Down
12 changes: 7 additions & 5 deletions vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,10 @@ def __init__(self, moe: FusedMoEConfig):

self.marlin_input_dtype = None
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
# Be conservative and tune for the most extreme inbalance for MoE,
# i.e., one expert receives all the tokens.
self.tune_max_num_tokens = (
get_current_vllm_config().scheduler_config.max_num_batched_tokens
)
Comment thread
nvjullin marked this conversation as resolved.

assert self.mxfp4_backend != Mxfp4Backend.NONE, (
Expand Down Expand Up @@ -868,7 +870,7 @@ def select_gemm_impl(
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
"tune_max_num_tokens": self.tune_max_num_tokens,
}
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
Expand Down Expand Up @@ -981,7 +983,7 @@ def apply(
None,
1 if layer.renormalize else 0, # routing_method_type, renormalize
True, # do finalize
tune_max_num_tokens=max(self.max_capture_size, 1),
tune_max_num_tokens=self.tune_max_num_tokens,
)[0]
return trtllm_gen_output
elif (
Expand Down Expand Up @@ -1048,7 +1050,7 @@ def apply(
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
tune_max_num_tokens=self.tune_max_num_tokens,
**extra_kwargs,
)

Expand Down