diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 29fdbffe470a..c23107965340 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, - register_moe_scaling_factors, + register_scales_for_trtllm_fp8_per_tensor_moe, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31, ) @@ -85,7 +85,7 @@ class TestData: @staticmethod def make_moe_tensors_8bit( - m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu" + m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu" ) -> "TestData": is_gated = activation != "relu2_no_mul" @@ -123,12 +123,17 @@ def make_moe_tensors_8bit( all2all_backend="naive", ) - register_moe_scaling_factors(layer) - # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - if reorder: + if is_trtllm: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + register_scales_for_trtllm_fp8_per_tensor_moe( + layer, + layer.w13_weight_scale, + layer.w13_input_scale, + layer.w2_weight_scale, + layer.w2_input_scale, + ) layer.custom_routing_function = Llama4MoE.custom_routing_function layer.intermediate_size_per_partition = n layer.ep_rank = 0 @@ -162,7 +167,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( set_random_seed(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): - td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) + td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) topk_weights, topk_ids = Llama4MoE.custom_routing_function( @@ -227,7 +232,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): td = TestData.make_moe_tensors_8bit( - m, k, n, e, reorder=False, activation=activation + m, k, n, e, is_trtllm=False, activation=activation ) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 4266514bc94e..3f298f7a5ca2 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -452,11 +452,14 @@ def make( - a1_scale: Optional scale to be used for a1. - a2_scale: Optional scale to be used for a2. - g1_alphas: Optional global quantization scales for w1 (for nvfp4). - per-channel scales for w1 (for W4A8 FP8). + Optional per-channel scales for w1 (for W4A8 FP8). + Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). - g2_alphas: Optional global quantization scales for w2 (for nvfp4). - per-channel scales for w2 (for W4A8 FP8). - - a1_gscale: Optional global quantization scales for a1 (for nvfp4). - - a2_gscale: Optional global quantization scales for a2 (for nvfp4). + Optional per-channel scales for w2 (for W4A8 FP8). + Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). + - a1_gscale: Optional global quantization scales for a1 (1.0 /a2_scale). + - a2_gscale: Optional global quantization scales for a2 (1.0 /a2_scale). + - w1_bias: Optional biases for w1 (GPT OSS Triton). - w2_bias: Optional biases for w1 (GPT OSS Triton). - w1_zp: Optional w1 zero points for int4/int8 quantization. diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 5353830db04b..09c3d9b2190f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -165,10 +165,10 @@ def apply( ): # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ - self.g1_alphas, - self.a2_gscale, - self.g2_alphas, - self.a1_gscale, + self.g1_alphas, # w13_weight_scale * w13_input_scale + self.a2_gscale, # 1.0 / w2_input_scale + self.g2_alphas, # w2_weight_scale * w2_input_scale + self.a1_scale, ] a1q_scale = None # not passing input_sf in fp8 diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 4135c30e82be..0b0efdafbd4d 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -184,13 +184,14 @@ def prepare( self._apply_router_weight_on_input( a1, topk_weights, topk_ids, apply_router_weight_on_input ) - if not self.use_dp and quant_config.quant_dtype == "nvfp4": + is_nvfp4 = quant_config.quant_dtype == "nvfp4" + if not self.use_dp and is_nvfp4: return a1, None, None, topk_ids, topk_weights if not self.use_deepseek_fp8_block_scale: a1q, a1q_scale = moe_kernel_quantize_input( a1, - quant_config.a1_gscale, + quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, @@ -222,7 +223,7 @@ def prepare( topk_weights, topk_ids, a1q = gathered a1q_scale = None - if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None: + if is_nvfp4 and a1q_scale is not None: a1q_scale = nvfp4_block_scale_interleave(a1q_scale) return a1q, a1q_scale, None, topk_ids, topk_weights diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f9104b6bf7f5..1223c6902e5f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -50,7 +50,8 @@ apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, get_flashinfer_moe_backend, - register_moe_scaling_factors, + make_fp8_moe_alpha_scales_for_fi, + register_scales_for_trtllm_fp8_per_tensor_moe, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -774,6 +775,14 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " "activation function, but got {layer.activation}." ) + dynamic_per_token = ( + not self.block_quant and self.quant_config.activation_scheme != "static" + ) + if self.flashinfer_moe_backend is not None and dynamic_per_token: + raise NotImplementedError( + "FlashInfer FP8 MoE backend does not support dynamic per token " + "activation quantization." + ) def create_weights( self, @@ -905,6 +914,8 @@ def _convert_weights_to_kernel_format( w2_weight: torch.Tensor, w13_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, + w13_input_scale: torch.Tensor | None, + w2_input_scale: torch.Tensor | None, ) -> None: if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: assert self.block_quant @@ -949,11 +960,16 @@ def _convert_weights_to_kernel_format( if self.block_quant: w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: - # TODO(rob): this function is a hack that renames the scaling - # factors in the Module. This is a hack we should clean up. - register_moe_scaling_factors(layer) if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) + register_scales_for_trtllm_fp8_per_tensor_moe( + layer=layer, + w13_weight_scale=w13_weight, + w13_input_scale=w13_input_scale, + w2_weight_scale=w2_weight, + w2_input_scale=w2_input_scale, + ) + elif self.fp8_backend == Fp8MoeBackend.AITER: w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( w13_weight, w2_weight @@ -990,6 +1006,10 @@ def _setup_kernel(self, layer: Module) -> None: AiterExperts, ) + # Flashinfer TRTLLM does not use the modular kernel abstraction. + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + return + self.moe_quant_config = self.get_fused_moe_quant_config(layer) assert self.moe_quant_config is not None self.use_inplace = True @@ -1087,7 +1107,13 @@ def process_weights_after_loading(self, layer: Module) -> None: # Shuffle weights into the runtime format. self._convert_weights_to_kernel_format( - layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale + layer=layer, + w13_weight=w13_weight, + w2_weight=w2_weight, + w13_weight_scale=w13_weight_scale, + w2_weight_scale=w2_weight_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, ) # Setup modular kernel for TP case. @@ -1182,6 +1208,11 @@ def select_gemm_impl( def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: + # TRTLLM does not use Modular Kernel. + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + return None + + # MARLIN uses mixed precision W8A16 config. if self.fp8_backend == Fp8MoeBackend.MARLIN: return fp8_w8a16_moe_quant_config( w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), @@ -1189,11 +1220,38 @@ def get_fused_moe_quant_config( block_shape=self.weight_block_size, ) + w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") + w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale + + # Flashinfer CUTLASS per-tensor uses single dq scale + # (alpha = w_scale * a_scale) and inverse a2 scale. + if ( + self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS + and not self.block_quant + ): + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + w1_scale, + a1_scale, + w2_scale, + a2_scale, + ) + return fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=(1.0 / a2_scale), + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + ) + + # All other backends use normal config. return fp8_w8a8_moe_quant_config( - w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), - w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, block_shape=self.weight_block_size, ) @@ -1414,7 +1472,13 @@ def process_weights_after_loading(self, layer: Module) -> None: # Shuffle weights into the runtime format. self._convert_weights_to_kernel_format( - layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale + layer=layer, + w13_weight=w13_weight, + w2_weight=w2_weight, + w13_weight_scale=layer.w13_weight_scale, + w2_weight_scale=layer.w2_weight_scale, + w13_input_scale=None, + w2_input_scale=None, ) # Setup modular kernel for TP case. diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 115edb2b3a34..b6752d7f9913 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -50,7 +50,8 @@ flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, - register_moe_scaling_factors, + make_fp8_moe_alpha_scales_for_fi, + register_scales_for_trtllm_fp8_per_tensor_moe, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -947,9 +948,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.flashinfer_moe_backend is not None: if self.moe.is_act_and_mul: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) + + # NOTE: this adds some attributes used by the trtllm kernel, + # which does not conform to the modular kernels abstraction (yet). if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) - register_moe_scaling_factors(layer) + register_scales_for_trtllm_fp8_per_tensor_moe( + layer=layer, + w13_weight_scale=layer.w13_weight_scale, + w13_input_scale=layer.w13_input_scale, + w2_weight_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, + ) def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. @@ -999,19 +1009,34 @@ def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + # TRTLLM does not use modular kernels return None - return fp8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - g1_alphas=layer.output1_scales_gate_scalar.squeeze(), - w2_scale=layer.w2_weight_scale, - g2_alphas=layer.output2_scales_scalar.squeeze(), - a1_scale=layer.w13_input_scale, - a1_gscale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - a2_gscale=layer.w2_input_scale_inv, - per_act_token_quant=False, - ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + layer.w13_weight_scale, + layer.w13_input_scale, + layer.w2_weight_scale, + layer.w2_input_scale, + ) + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + a1_gscale=(1.0 / layer.w13_input_scale), + a2_gscale=(1.0 / layer.w2_input_scale), + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + ) + else: + assert self.flashinfer_moe_backend is None + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) def apply( self, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 3d6e9cda8766..b73c44b3130d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights( ) +def register_scales_for_trtllm_fp8_per_tensor_moe( + layer: torch.nn.Module, + w13_weight_scale: torch.Tensor, + w13_input_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w2_input_scale: torch.Tensor, +) -> None: + """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel""" + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + w13_scale=w13_weight_scale, + w13_input_scale=w13_input_scale, + w2_scale=w2_weight_scale, + w2_input_scale=w2_input_scale, + ) + layer.w2_input_scale_inv = 1.0 / w2_input_scale + layer.output1_scales_gate_scalar = g1_alphas + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv + layer.output2_scales_scalar = g2_alphas + + def apply_flashinfer_per_tensor_scale_fp8( layer: torch.nn.Module, hidden_states: torch.Tensor, @@ -117,18 +137,13 @@ def apply_flashinfer_per_tensor_scale_fp8( from flashinfer.fused_moe import RoutingMethodType import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + from vllm.model_executor.models.llama4 import Llama4MoE - assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_scalar to be initialized" - ) - assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_gate_scalar to be initialized" + assert ( + hasattr(layer, "output1_scales_scalar") + and hasattr(layer, "output1_scales_gate_scalar") + and hasattr(layer, "output2_scales_scalar") ) - assert layer.output1_scales_scalar is not None, ( - "Expected output2_scales_scalar to be initialized" - ) - - from vllm.model_executor.models.llama4 import Llama4MoE assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( "FusedMoE flashinfer kernels are only supported for Llama4" @@ -155,40 +170,16 @@ def apply_flashinfer_per_tensor_scale_fp8( ) -def get_moe_scaling_factors( - input_scale: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - activation_scale: torch.Tensor, - gemm2_weights_scale: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale) - output1_scales_gate_scalar = gemm1_weights_scale * input_scale - output2_scales_scalar = activation_scale * gemm2_weights_scale - - return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar +def make_fp8_moe_alpha_scales_for_fi( + w13_scale: torch.Tensor, + w13_input_scale: torch.Tensor, + w2_scale: torch.Tensor, + w2_input_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + g1_alphas = (w13_scale * w13_input_scale).squeeze() + g2_alphas = (w2_scale * w2_input_scale).squeeze() - -def register_moe_scaling_factors(layer: torch.nn.Module) -> None: - output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors( - layer.w13_input_scale, - layer.w13_weight_scale, - layer.w2_input_scale, - layer.w2_weight_scale, - ) - layer.register_parameter( - "output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False) - ) - layer.register_parameter( - "output1_scales_gate_scalar", - torch.nn.Parameter(output1_gate_scales, requires_grad=False), - ) - layer.register_parameter( - "output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False) - ) - layer.register_parameter( - "w2_input_scale_inv", - torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False), - ) + return g1_alphas, g2_alphas def build_flashinfer_fp8_cutlass_moe_prepare_finalize(