diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py index 6481434f2e78..38200d9d0905 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py @@ -54,6 +54,11 @@ def __init__( self.out_dtype = moe_config.in_dtype self.num_local_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank + # FC2 input scale tensor bound in process_weights_after_loading: the + # calibrated (now-zeroed) a2_gscale for static-quant checkpoints, or + # a synthesized uniform-1.0 tensor for W4A16 checkpoints that lack + # one. Holding it on the instance keeps apply() alloc-free. + self._fc2_input_scale: torch.Tensor | None = None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Normalise block scales to absorb the per-expert weight global scale @@ -86,6 +91,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # its own per-block dynamic scale. if self.a2_gscale is not None: self.a2_gscale.fill_(1.0) + self._fc2_input_scale = self.a2_gscale + else: + # W4A16 NVFP4 checkpoints have no calibrated a2_gscale; b12x + # performs dynamic per-block FC2-input quantization, so a uniform + # 1.0 scale per expert is equivalent to the bake-in above for + # static-quant checkpoints. Allocate once here so apply() stays + # alloc-free. + self._fc2_input_scale = torch.ones( + self.num_local_experts, + device=layer.w13_weight.device, + dtype=torch.float32, + ) # Precompute MMA-layout views of the weight scale factors once here # rather than recomputing on every forward pass. @@ -131,7 +148,13 @@ def _supports_quant_scheme( weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic) + # b12x performs in-kernel BF16->FP4 activation quant, so W4A16 + # NVFP4 checkpoints (activation_key=None, e.g. mixed-precision + # compressed-tensors layouts) are runtime-compatible. + return (weight_key, activation_key) in ( + (kNvfp4Static, kNvfp4Dynamic), + (kNvfp4Static, None), + ) @staticmethod def _supports_activation(activation: MoEActivation) -> bool: @@ -198,8 +221,8 @@ def apply( assert self.g1_alphas is not None and self.g2_alphas is not None, ( "g1_alphas and g2_alphas must not be None for FlashInferB12xExperts" ) - assert self.a2_gscale is not None, ( - "a2_gscale must not be None for FlashInferB12xExperts" + assert self._fc2_input_scale is not None, ( + "_fc2_input_scale must be set by process_weights_after_loading" ) top_k = topk_ids.shape[1] @@ -211,7 +234,7 @@ def apply( w1_weight=w1, w1_weight_sf=self.w1_sf_mma, w1_alpha=self.g1_alphas, - fc2_input_scale=self.a2_gscale, + fc2_input_scale=self._fc2_input_scale, w2_weight=w2, w2_weight_sf=self.w2_sf_mma, w2_alpha=self.g2_alphas,