Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
d6a1f64
initial commit
Dec 31, 2025
9a28683
move to cuda before the reshapes for r&d
Jan 1, 2026
e84eaa2
clean
Jan 1, 2026
058a998
clean
Jan 1, 2026
d53b6ff
clean
Jan 1, 2026
9a7cf4d
clean
Jan 1, 2026
1182e1d
clean
Jan 1, 2026
2126f98
clean
Jan 1, 2026
33741a8
clean
Jan 1, 2026
31c4e22
stash
Jan 1, 2026
e0129dd
working end to end
Jan 1, 2026
eb6699b
comment nits
Jan 1, 2026
5be7ab1
comment nits
Jan 1, 2026
844a65a
remove
Jan 1, 2026
24a0302
rename method
Jan 1, 2026
7edf70f
stash trtllm fix
Jan 1, 2026
9d994a6
stash changes
Jan 1, 2026
6ff4b75
Merge branch 'fix-flashinfer-experts-quant-config-hack' of https://gi…
Jan 1, 2026
c9a7e5b
updated
Jan 1, 2026
2408ad2
make trtllm work
Jan 1, 2026
96ff599
ad back import
Jan 1, 2026
f9a4724
update comments
Jan 1, 2026
e8831f9
apply changes to fp8.py
Jan 1, 2026
f8f9a33
nit
Jan 1, 2026
59f97a6
revert unneeded assets
Jan 1, 2026
113e472
rename
Jan 1, 2026
a98a380
update comment
Jan 1, 2026
df82e9c
Merge branch 'fix-flashinfer-experts-quant-config-hack' of https://gi…
Jan 1, 2026
df5035c
naming
Jan 1, 2026
3678402
add back check to prevent mixtral
Jan 1, 2026
783b64d
Merge branch 'main' into fix-flashinfer-experts-quant-config-hack
robertgshaw2-redhat Jan 1, 2026
c30d404
remove delete me
Jan 2, 2026
a285f5e
update to address pavani's feedback
Jan 2, 2026
2d96161
reduce LOC change
Jan 2, 2026
a910872
fix
Jan 2, 2026
d2decd6
reduce loc nit
Jan 2, 2026
870fc6a
clean up fi checking
Jan 2, 2026
23e79fd
updated
Jan 2, 2026
86a0e5c
fix assign float to parameter
Jan 2, 2026
7eaa18b
updated doc string
Jan 2, 2026
83a7d9b
fix up
Jan 2, 2026
7300bc5
fix up
Jan 2, 2026
344167d
fix up configs
Jan 2, 2026
b887c4f
cleanup
Jan 2, 2026
d4d4231
remove unneeded assert
Jan 2, 2026
218e697
standardize how we add the params
Jan 2, 2026
b2e3a50
updated
Jan 2, 2026
dd30416
updated
Jan 2, 2026
3d22ba3
updated
Jan 2, 2026
56edeca
a few small nits
Jan 2, 2026
e917f5d
fix tests
Jan 2, 2026
39987f6
revert the llama weight loading hack
Jan 2, 2026
140f447
stash
Jan 2, 2026
de6faa1
unstash
Jan 2, 2026
173e67d
Merge remote-tracking branch 'origin/main' into fix-flashinfer-expert…
Jan 3, 2026
12a0638
update cutlass per tensor
Jan 5, 2026
6982255
Merge branch 'main' into fix-flashinfer-experts-quant-config-hack
robertgshaw2-redhat Jan 5, 2026
59db9a9
update cutlass per tensor
Jan 5, 2026
d6c4a87
update
Jan 5, 2026
ce913de
fix trtllm issue
Jan 5, 2026
a8c5cc9
Merge branch 'main' into fix-flashinfer-experts-quant-config-hack
robertgshaw2-redhat Jan 6, 2026
84dc7ea
Merge branch 'main' into fix-flashinfer-experts-quant-config-hack
robertgshaw2-redhat Jan 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions tests/kernels/moe/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31,
)
Expand Down Expand Up @@ -85,7 +85,7 @@ class TestData:

@staticmethod
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
) -> "TestData":
is_gated = activation != "relu2_no_mul"

Expand Down Expand Up @@ -123,12 +123,17 @@ def make_moe_tensors_8bit(
all2all_backend="naive",
)

register_moe_scaling_factors(layer)

# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder:
if is_trtllm:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
Expand Down Expand Up @@ -162,7 +167,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)

score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = Llama4MoE.custom_routing_function(
Expand Down Expand Up @@ -227,7 +232,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, reorder=False, activation=activation
m, k, n, e, is_trtllm=False, activation=activation
)

score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
Expand Down
11 changes: 7 additions & 4 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,14 @@ def make(
- a1_scale: Optional scale to be used for a1.
- a2_scale: Optional scale to be used for a2.
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
per-channel scales for w1 (for W4A8 FP8).
Optional per-channel scales for w1 (for W4A8 FP8).
Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
per-channel scales for w2 (for W4A8 FP8).
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
Optional per-channel scales for w2 (for W4A8 FP8).
Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8).
- a1_gscale: Optional global quantization scales for a1 (1.0 / a1_scale).
- a2_gscale: Optional global quantization scales for a2 (1.0 / a2_scale).

- w1_bias: Optional biases for w1 (GPT OSS Triton).
- w2_bias: Optional biases for w2 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ def apply(
):
# FP8 per-tensor path: use global alphas/scales; do not pass input_sf
quant_scales = [
self.g1_alphas,
self.a2_gscale,
self.g2_alphas,
self.a1_gscale,
self.g1_alphas, # w13_weight_scale * w13_input_scale
self.a2_gscale, # 1.0 / w2_input_scale
self.g2_alphas, # w2_weight_scale * w2_input_scale
self.a1_scale,
]

a1q_scale = None # not passing input_sf in fp8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,14 @@ def prepare(
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
if not self.use_dp and quant_config.quant_dtype == "nvfp4":
is_nvfp4 = quant_config.quant_dtype == "nvfp4"
if not self.use_dp and is_nvfp4:
return a1, None, None, topk_ids, topk_weights

if not self.use_deepseek_fp8_block_scale:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
Expand Down Expand Up @@ -222,7 +223,7 @@ def prepare(
topk_weights, topk_ids, a1q = gathered
a1q_scale = None

if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None:
if is_nvfp4 and a1q_scale is not None:
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

return a1q, a1q_scale, None, topk_ids, topk_weights
Expand Down
84 changes: 74 additions & 10 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
apply_flashinfer_per_tensor_scale_fp8,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
get_flashinfer_moe_backend,
register_moe_scaling_factors,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
Expand Down Expand Up @@ -774,6 +775,14 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
"activation function, but got {layer.activation}."
)
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
if self.flashinfer_moe_backend is not None and dynamic_per_token:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)

def create_weights(
self,
Expand Down Expand Up @@ -905,6 +914,8 @@ def _convert_weights_to_kernel_format(
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
assert self.block_quant
Expand Down Expand Up @@ -949,11 +960,16 @@ def _convert_weights_to_kernel_format(
if self.block_quant:
w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
else:
# TODO(rob): this function is a hack that renames the scaling
# factors in the Module. This is a hack we should clean up.
register_moe_scaling_factors(layer)
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=w13_weight,
w13_input_scale=w13_input_scale,
w2_weight_scale=w2_weight,
w2_input_scale=w2_input_scale,
)

elif self.fp8_backend == Fp8MoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
w13_weight, w2_weight
Expand Down Expand Up @@ -990,6 +1006,10 @@ def _setup_kernel(self, layer: Module) -> None:
AiterExperts,
)

# Flashinfer TRTLLM does not use the modular kernel abstraction.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return

self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.use_inplace = True
Expand Down Expand Up @@ -1087,7 +1107,13 @@ def process_weights_after_loading(self, layer: Module) -> None:

# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)

# Setup modular kernel for TP case.
Expand Down Expand Up @@ -1182,18 +1208,50 @@ def select_gemm_impl(
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None

# MARLIN uses mixed precision W8A16 config.
if self.fp8_backend == Fp8MoeBackend.MARLIN:
return fp8_w8a16_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
block_shape=self.weight_block_size,
)

w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale

# Flashinfer CUTLASS per-tensor uses single dq scale
# (alpha = w_scale * a_scale) and inverse a2 scale.
if (
self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
and not self.block_quant
):
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w1_scale,
a1_scale,
w2_scale,
a2_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=(1.0 / a2_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)

# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=self.weight_block_size,
)

Expand Down Expand Up @@ -1414,7 +1472,13 @@ def process_weights_after_loading(self, layer: Module) -> None:

# Shuffle weights into the runtime format.
self._convert_weights_to_kernel_format(
layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale
layer=layer,
w13_weight=w13_weight,
w2_weight=w2_weight,
w13_weight_scale=layer.w13_weight_scale,
w2_weight_scale=layer.w2_weight_scale,
w13_input_scale=None,
w2_input_scale=None,
)

# Setup modular kernel for TP case.
Expand Down
51 changes: 38 additions & 13 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
register_moe_scaling_factors,
make_fp8_moe_alpha_scales_for_fi,
register_scales_for_trtllm_fp8_per_tensor_moe,
rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
Expand Down Expand Up @@ -947,9 +948,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.flashinfer_moe_backend is not None:
if self.moe.is_act_and_mul:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)

# NOTE: this adds some attributes used by the trtllm kernel,
# which does not conform to the modular kernels abstraction (yet).
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
register_moe_scaling_factors(layer)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer=layer,
w13_weight_scale=layer.w13_weight_scale,
w13_input_scale=layer.w13_input_scale,
w2_weight_scale=layer.w2_weight_scale,
w2_input_scale=layer.w2_input_scale,
)

def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
"""Pad intermediate size so FlashInfer kernels' alignment constraints hold.
Expand Down Expand Up @@ -999,19 +1009,34 @@ def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
# TRTLLM does not use modular kernels
return None

return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
g1_alphas=layer.output1_scales_gate_scalar.squeeze(),
w2_scale=layer.w2_weight_scale,
g2_alphas=layer.output2_scales_scalar.squeeze(),
a1_scale=layer.w13_input_scale,
a1_gscale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a2_gscale=layer.w2_input_scale_inv,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the a2_gscales are used for quantization of hidden states (a tensor) before FFN2 in MOE, hence the a2 (for second FFN) and gscale for quantization.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated to reflect this

per_act_token_quant=False,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
layer.w13_weight_scale,
layer.w13_input_scale,
layer.w2_weight_scale,
layer.w2_input_scale,
)
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
a1_gscale=(1.0 / layer.w13_input_scale),
Copy link
Copy Markdown
Contributor

@amirkl94 amirkl94 Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function is called every forward, which means these 2 lines will result in 2 kernel launches for reciprocal:

a1_gscale=(1.0 / layer.w13_input_scale),
a2_gscale=(1.0 / layer.w2_input_scale),

Can we add these 2 scales in process_weights_after_loading ?

Copy link
Copy Markdown
Collaborator Author

@robertgshaw2-redhat robertgshaw2-redhat Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its not called in the forward pass. I recognize this is confusing, but the apply() method is not called during the forward pass for flashinfer kernels. When flashinfer CUTLASS kernels are selected, the FpMoeMethod is converted into a ModularKernelMethod

I am working on an ongoing refactor that makes the conversion

see https://vllm-dev.slack.com/archives/C08NFPURQ1F/p1767650816469009 for more details on my efforts

a2_gscale=(1.0 / layer.w2_input_scale),
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
)
else:
assert self.flashinfer_moe_backend is None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)

def apply(
self,
Expand Down
Loading