From d6a1f6405628b7f6120765f10996b7e8e8f1cb3b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Wed, 31 Dec 2025 22:32:24 +0000 Subject: [PATCH 01/55] initial commit Signed-off-by: Robert Shaw --- .../fused_moe/flashinfer_cutlass_moe.py | 8 +++--- .../flashinfer_cutlass_prepare_finalize.py | 8 +++--- .../layers/quantization/modelopt.py | 26 ++++++++++++++----- .../quantization/utils/flashinfer_utils.py | 12 +++++++++ 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index f864634c6617..603b7a5da743 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -165,10 +165,10 @@ def apply( ): # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ - self.g1_alphas, - self.a2_gscale, - self.g2_alphas, - self.a1_gscale, + self.w1_scale, # done: g1_alphas, layer.output1_scales_gate_scalar.squeeze(), # noqa: E501 + self.w2_scale, # done: a2_gscale, layer.output2_scales_scalar.squeeze(), # noqa: E501 + self.a2_scale, # done: self.g2_alphas, # layer.w2_input_scale_inv + self.a1_scale, # done: self.a1_gscale, # layer.w13_input_scale ] a1q_scale = None # not passing input_sf in fp8 diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 762890867e60..1a4f6598c296 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -181,13 +181,15 @@ def prepare( self._apply_router_weight_on_input( a1, topk_weights, topk_ids, apply_router_weight_on_input ) - if not self.use_dp and quant_config.quant_dtype == "nvfp4": + is_nvfp4 = quant_config.quant_dtype == "nvfp4" + if not self.use_dp and is_nvfp4: 
return a1, None, None, topk_ids, topk_weights if not self.use_deepseek_fp8_block_scale: a1q, a1q_scale = moe_kernel_quantize_input( a1, - quant_config.a1_gscale, + # TODO(rob): look at the signature of the nvfp4 case. + quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, @@ -219,7 +221,7 @@ def prepare( topk_weights, topk_ids, a1q = gathered a1q_scale = None - if quant_config.quant_dtype == "nvfp4" and a1q_scale is not None: + if is_nvfp4 and a1q_scale is not None: a1q_scale = nvfp4_block_scale_interleave(a1q_scale) return a1q, a1q_scale, None, topk_ids, topk_weights diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index afbefe1fedc1..836e0ff90f21 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -45,12 +45,13 @@ ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, + # register_moe_scaling_factors, + _convert_moe_scaling_factors, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -945,7 +946,20 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) - register_moe_scaling_factors(layer) + # register_moe_scaling_factors(layer) + w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( + _convert_moe_scaling_factors( + layer.w13_weight_scale, + layer.w13_input_scale, + layer.w2_weight_scale, + 
layer.w2_input_scale, + ) + ) + + layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False) + layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False) + layer.w13_input_scale = Parameter(w13_input_scale, requires_grad=False) + layer.w2_input_scale = Parameter(w2_input_scale, requires_grad=False) def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. @@ -999,13 +1013,13 @@ def get_fused_moe_quant_config( return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, - g1_alphas=layer.output1_scales_gate_scalar.squeeze(), w2_scale=layer.w2_weight_scale, - g2_alphas=layer.output2_scales_scalar.squeeze(), a1_scale=layer.w13_input_scale, - a1_gscale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - a2_gscale=layer.w2_input_scale_inv, + # g1_alphas=layer.output1_scales_gate_scalar.squeeze(), + # g2_alphas=layer.output2_scales_scalar.squeeze(), + # a1_gscale=layer.w13_input_scale, + # a2_gscale=layer.w2_input_scale_inv, per_act_token_quant=False, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 3d6e9cda8766..67d3f7ee830d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -168,6 +168,18 @@ def get_moe_scaling_factors( return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar +def _convert_moe_scaling_factors( + w13_scale: torch.Tensor, + w13_input_scale: torch.Tensor, + w2_scale: torch.Tensor, + w2_input_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + w13_scale = (w13_scale * w13_input_scale).squeeze() + w2_scale = (w2_scale * w2_input_scale).squeeze() + w2_input_scale = 1.0 / w2_input_scale + return w13_scale, w13_input_scale, w2_scale, w2_input_scale + + def 
register_moe_scaling_factors(layer: torch.nn.Module) -> None: output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors( layer.w13_input_scale, From 9a286832bd4c21ae99ea4825bf279734af2dfc86 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:41:48 +0000 Subject: [PATCH 02/55] move to cuda before the reshapes for r&d Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_cutlass_moe.py | 4 ++-- vllm/model_executor/layers/fused_moe/layer.py | 16 ++++++++++++++++ vllm/model_executor/models/llama4.py | 16 ++++++++++++++++ vllm/model_executor/models/mllama4.py | 4 ++++ vllm/model_executor/models/utils.py | 7 +++++++ 5 files changed, 45 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 603b7a5da743..214b32142a11 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -165,8 +165,8 @@ def apply( ): # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ - self.w1_scale, # done: g1_alphas, layer.output1_scales_gate_scalar.squeeze(), # noqa: E501 - self.w2_scale, # done: a2_gscale, layer.output2_scales_scalar.squeeze(), # noqa: E501 + self.w1_scale, # done: g1_alphas, layer.output1_scales_gate_scalar.squeeze(), # noqa: E501 + self.w2_scale, # done: a2_gscale, layer.output2_scales_scalar.squeeze(), # noqa: E501 self.a2_scale, # done: self.g2_alphas, # layer.w2_input_scale_inv self.a1_scale, # done: self.a1_gscale, # layer.w13_input_scale ] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f0d94bfbcaba..816703e7dcfb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -951,6 +951,9 @@ def _load_model_weight_or_group_weight_scale( load_full=load_full_w2, ) elif 
shard_id in ("w1", "w3"): + logger.info( + f"{loaded_weight.is_contiguous()=} | {expert_data.is_contiguous()=}" + ) self._load_w13( shard_id=shard_id, shard_dim=shard_dim, @@ -1006,7 +1009,16 @@ def _load_w13( else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + import time + + start = time.perf_counter() + loaded_weight = loaded_weight.contiguous() + done_cont = time.perf_counter() expert_data.copy_(loaded_weight) + done = time.perf_counter() + logger.info( + f"{done_cont - start:.3f}s to contiguous, {done - done_cont:.3f}s to copy" + ) def _load_w2( self, @@ -1259,6 +1271,7 @@ def weight_loader( else "weight_scale" in weight_name ) or "input_scale" in weight_name if is_per_tensor: + logger.info("_load_per_tensor_weight_scale") self._load_per_tensor_weight_scale( shard_id=shard_id, param=param, @@ -1276,6 +1289,7 @@ def weight_loader( loaded_weight_hidden_out = loaded_weight.shape[-2] param_hidden_out = param.data.shape[-2] * self.tp_size if loaded_weight_hidden_out == param_hidden_out: + logger.info("_load_combined_w13_weight_scale") self._load_combined_w13_weight_scale( shard_dim=shard_dim, loaded_weight=loaded_weight, @@ -1287,6 +1301,7 @@ def weight_loader( # For other weights, call _load_model_weight_or_group_weight_scale() # to load it. 
if "weight" in weight_name: + logger.info(f"_load_model_weight_or_group_weight_scale: {weight_name}") self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, @@ -1348,6 +1363,7 @@ def weight_loader( # Case model weights if "weight" in weight_name: + # logger.info("H: _load_model_weight_or_group_weight_scale") self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 7b3da3e10ab8..731dae5175d2 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,6 +432,7 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. expert_param_loaded = False + loaded_weight = loaded_weight.to("cuda") # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last @@ -520,6 +521,21 @@ def load_moe_expert_weights( loaded_params.add(full_param_name) expert_param_loaded = True + # start = time.perf_counter() + # if isinstance(loaded_weight, tuple): + # for w in loaded_weight: + # w = w.contiguous() + # w = w.to("cpu") + # for w in new_loaded_weight: + # w = w.contiguous() + # w = w.to("cpu") + # else: + # loaded_weight = loaded_weight.contiguous() + # loaded_weight = loaded_weight.to("cpu") + # new_loaded_weight = new_loaded_weight.contiguous() + # new_loaded_weight = new_loaded_weight.to("cpu") + end = time.perf_counter() + logger.info(f"move to cpu time: {end - start}s") return expert_param_loaded def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 886d5151e43f..996f1c80e09c 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -1123,18 +1123,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> 
set[str]: updated_params.update(updated_params_from_experts) loader = AutoWeightsLoader(self) + print("starting load weights...") loaded_language_model_params = loader.load_weights(regular_weights) assert loaded_language_model_params is not None updated_params.update(loaded_language_model_params) + print("finished load weights...") if expert_scale_weights: loaded_expert_scale_params = loader.load_weights(expert_scale_weights) if loaded_expert_scale_params: updated_params.update(loaded_expert_scale_params) + print("finished load scales...") updated_params.update( self._load_other_weights(other_weights, params_dict, stacked_params_mapping) ) + print("finished load other weights...") return updated_params diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f25ab9153a50..e177e3fac52e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -195,6 +195,7 @@ def _load_param( param: nn.Parameter, weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[str]: + logger.info("Load param!!!!!") for weight_name, weight_data in weights: weight_qualname = self._get_qualname(base_prefix, weight_name) @@ -215,6 +216,7 @@ def _load_param( ) weight_loader = getattr(param, "weight_loader", default_weight_loader) + logger.info(f"{weight_qualname=}: {weight_data.is_contiguous()=}") weight_loader(param, weight_data) logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape) @@ -250,6 +252,7 @@ def _load_module( module: nn.Module, weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[str]: + logger.info("load module!!!!!") if isinstance(module, PPMissingLayer): return @@ -258,6 +261,7 @@ def _load_module( if module != self.module: module_load_weights = getattr(module, "load_weights", None) if callable(module_load_weights): + logger.info("calling module_load_weights") loaded_params = module_load_weights(weights) if loaded_params is None: logger.warning( @@ -285,6 +289,7 @@ def 
_load_module( continue + logger.info("calling _load_module from _load_module") yield from self._load_module( prefix, child_modules[child_prefix], child_weights ) @@ -294,6 +299,7 @@ def _load_module( continue + logger.info("calling _load_param from _load_module") yield from self._load_param( prefix, child_params[child_prefix], child_weights ) @@ -325,6 +331,7 @@ def load_weights( *, mapper: WeightsMapper | None = None, ) -> set[str]: + logger.info(f"load_weights: {mapper=}") if mapper is not None: weights = mapper.apply(weights) # filter out weights with first-prefix/substr to skip in name From e84eaa2cc8d4b222378c2daed67858dc85c36db3 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:43:59 +0000 Subject: [PATCH 03/55] clean Signed-off-by: Robert Shaw --- vllm/model_executor/models/llama4.py | 4 ++-- vllm/model_executor/models/mllama4.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 731dae5175d2..3a6cb10e3c9f 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -534,8 +534,8 @@ def load_moe_expert_weights( # loaded_weight = loaded_weight.to("cpu") # new_loaded_weight = new_loaded_weight.contiguous() # new_loaded_weight = new_loaded_weight.to("cpu") - end = time.perf_counter() - logger.info(f"move to cpu time: {end - start}s") + # end = time.perf_counter() + # logger.info(f"move to cpu time: {end - start}s") return expert_param_loaded def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 996f1c80e09c..886d5151e43f 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -1123,22 +1123,18 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: updated_params.update(updated_params_from_experts) loader = 
AutoWeightsLoader(self) - print("starting load weights...") loaded_language_model_params = loader.load_weights(regular_weights) assert loaded_language_model_params is not None updated_params.update(loaded_language_model_params) - print("finished load weights...") if expert_scale_weights: loaded_expert_scale_params = loader.load_weights(expert_scale_weights) if loaded_expert_scale_params: updated_params.update(loaded_expert_scale_params) - print("finished load scales...") updated_params.update( self._load_other_weights(other_weights, params_dict, stacked_params_mapping) ) - print("finished load other weights...") return updated_params From 058a998849a4a3550d771dda999baff5b0ba5fef Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:45:03 +0000 Subject: [PATCH 04/55] clean Signed-off-by: Robert Shaw --- vllm/model_executor/models/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index e177e3fac52e..965cd2530e41 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -252,7 +252,6 @@ def _load_module( module: nn.Module, weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[str]: - logger.info("load module!!!!!") if isinstance(module, PPMissingLayer): return @@ -299,7 +298,6 @@ def _load_module( continue - logger.info("calling _load_param from _load_module") yield from self._load_param( prefix, child_params[child_prefix], child_weights ) @@ -331,7 +329,6 @@ def load_weights( *, mapper: WeightsMapper | None = None, ) -> set[str]: - logger.info(f"load_weights: {mapper=}") if mapper is not None: weights = mapper.apply(weights) # filter out weights with first-prefix/substr to skip in name From d53b6ff6cd6bb31fd63116d1fcd0dbb48f6d0954 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:45:24 +0000 Subject: [PATCH 05/55] clean Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/layer.py | 3 --- 1 file 
changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 816703e7dcfb..a926f3644337 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -951,9 +951,6 @@ def _load_model_weight_or_group_weight_scale( load_full=load_full_w2, ) elif shard_id in ("w1", "w3"): - logger.info( - f"{loaded_weight.is_contiguous()=} | {expert_data.is_contiguous()=}" - ) self._load_w13( shard_id=shard_id, shard_dim=shard_dim, From 9a7cf4d71c7043de8c282b3a663a784e1ebff70d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:45:59 +0000 Subject: [PATCH 06/55] clean Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/layer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a926f3644337..68e01e744441 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1006,16 +1006,7 @@ def _load_w13( else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - import time - - start = time.perf_counter() - loaded_weight = loaded_weight.contiguous() - done_cont = time.perf_counter() expert_data.copy_(loaded_weight) - done = time.perf_counter() - logger.info( - f"{done_cont - start:.3f}s to contiguous, {done - done_cont:.3f}s to copy" - ) def _load_w2( self, From 1182e1dd1a6de3f6cff0aebba1857415814ac964 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:46:54 +0000 Subject: [PATCH 07/55] clean Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/layer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 68e01e744441..f0d94bfbcaba 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ 
b/vllm/model_executor/layers/fused_moe/layer.py @@ -1259,7 +1259,6 @@ def weight_loader( else "weight_scale" in weight_name ) or "input_scale" in weight_name if is_per_tensor: - logger.info("_load_per_tensor_weight_scale") self._load_per_tensor_weight_scale( shard_id=shard_id, param=param, @@ -1277,7 +1276,6 @@ def weight_loader( loaded_weight_hidden_out = loaded_weight.shape[-2] param_hidden_out = param.data.shape[-2] * self.tp_size if loaded_weight_hidden_out == param_hidden_out: - logger.info("_load_combined_w13_weight_scale") self._load_combined_w13_weight_scale( shard_dim=shard_dim, loaded_weight=loaded_weight, @@ -1289,7 +1287,6 @@ def weight_loader( # For other weights, call _load_model_weight_or_group_weight_scale() # to load it. if "weight" in weight_name: - logger.info(f"_load_model_weight_or_group_weight_scale: {weight_name}") self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, @@ -1351,7 +1348,6 @@ def weight_loader( # Case model weights if "weight" in weight_name: - # logger.info("H: _load_model_weight_or_group_weight_scale") self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, From 2126f980dd76b75ab6248af4ce1cdb9e66b8515b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:47:24 +0000 Subject: [PATCH 08/55] clean Signed-off-by: Robert Shaw --- vllm/model_executor/models/llama4.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 3a6cb10e3c9f..b9638bb12948 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,7 +432,7 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. expert_param_loaded = False - loaded_weight = loaded_weight.to("cuda") + loaded_weight = loaded_weight.to("cuda") # DELETE ME!!!! 
# If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last @@ -521,21 +521,6 @@ def load_moe_expert_weights( loaded_params.add(full_param_name) expert_param_loaded = True - # start = time.perf_counter() - # if isinstance(loaded_weight, tuple): - # for w in loaded_weight: - # w = w.contiguous() - # w = w.to("cpu") - # for w in new_loaded_weight: - # w = w.contiguous() - # w = w.to("cpu") - # else: - # loaded_weight = loaded_weight.contiguous() - # loaded_weight = loaded_weight.to("cpu") - # new_loaded_weight = new_loaded_weight.contiguous() - # new_loaded_weight = new_loaded_weight.to("cpu") - # end = time.perf_counter() - # logger.info(f"move to cpu time: {end - start}s") return expert_param_loaded def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: From 33741a8f29966a33457d8e89aeaeeb7f2dacbf18 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 00:48:01 +0000 Subject: [PATCH 09/55] clean Signed-off-by: Robert Shaw --- vllm/model_executor/models/utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 965cd2530e41..f25ab9153a50 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -195,7 +195,6 @@ def _load_param( param: nn.Parameter, weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[str]: - logger.info("Load param!!!!!") for weight_name, weight_data in weights: weight_qualname = self._get_qualname(base_prefix, weight_name) @@ -216,7 +215,6 @@ def _load_param( ) weight_loader = getattr(param, "weight_loader", default_weight_loader) - logger.info(f"{weight_qualname=}: {weight_data.is_contiguous()=}") weight_loader(param, weight_data) logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape) @@ -260,7 +258,6 @@ def _load_module( if module != self.module: module_load_weights = getattr(module, "load_weights", 
None) if callable(module_load_weights): - logger.info("calling module_load_weights") loaded_params = module_load_weights(weights) if loaded_params is None: logger.warning( @@ -288,7 +285,6 @@ def _load_module( continue - logger.info("calling _load_module from _load_module") yield from self._load_module( prefix, child_modules[child_prefix], child_weights ) From 31c4e229a9b2354511415105b4ccc1d769c9f303 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 14:56:26 +0000 Subject: [PATCH 10/55] stash Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_cutlass_moe.py | 10 ++++++---- vllm/model_executor/layers/quantization/modelopt.py | 12 ++++-------- .../layers/quantization/utils/flashinfer_utils.py | 10 +++++++--- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 214b32142a11..0d9fc4789b96 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -165,11 +165,13 @@ def apply( ): # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ - self.w1_scale, # done: g1_alphas, layer.output1_scales_gate_scalar.squeeze(), # noqa: E501 - self.w2_scale, # done: a2_gscale, layer.output2_scales_scalar.squeeze(), # noqa: E501 - self.a2_scale, # done: self.g2_alphas, # layer.w2_input_scale_inv - self.a1_scale, # done: self.a1_gscale, # layer.w13_input_scale + self.w1_scale, # done: g1_alphas + self.w2_scale, # done: a2_gscale + self.a2_scale, # done: g2_alphas + self.a1_scale, # done: a1_gscale ] + for i, quant_scale in enumerate(quant_scales): + logger.info_once(f"{i}, {quant_scale.shape=}") a1q_scale = None # not passing input_sf in fp8 fc1_expert_weights = w1 diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 
836e0ff90f21..6938871efc0b 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -949,10 +949,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # register_moe_scaling_factors(layer) w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( _convert_moe_scaling_factors( - layer.w13_weight_scale, - layer.w13_input_scale, - layer.w2_weight_scale, - layer.w2_input_scale, + w13_scale=layer.w13_weight_scale, + w13_input_scale=layer.w13_input_scale, + w2_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, ) ) @@ -1016,10 +1016,6 @@ def get_fused_moe_quant_config( w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - # g1_alphas=layer.output1_scales_gate_scalar.squeeze(), - # g2_alphas=layer.output2_scales_scalar.squeeze(), - # a1_gscale=layer.w13_input_scale, - # a2_gscale=layer.w2_input_scale_inv, per_act_token_quant=False, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 67d3f7ee830d..bd3dd7971cbd 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -174,9 +174,13 @@ def _convert_moe_scaling_factors( w2_scale: torch.Tensor, w2_input_scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - w13_scale = (w13_scale * w13_input_scale).squeeze() - w2_scale = (w2_scale * w2_input_scale).squeeze() - w2_input_scale = 1.0 / w2_input_scale + w13_scale = ( + w13_scale * w13_input_scale + ).squeeze() # output1_scales_gate_scalar, g1_alphas + w2_scale = ( + w2_scale * w2_input_scale + ).squeeze() # output2_scales_scalar, g2_alphas + w2_input_scale = 1.0 / w2_input_scale # w2_input_scale_inv, a2_gscale return w13_scale, w13_input_scale, w2_scale, w2_input_scale From 
e0129dda9e584e829365fd2a93d91659be11ce57 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:17:46 +0000 Subject: [PATCH 11/55] working end to end Signed-off-by: Robert Shaw --- .../fused_moe/flashinfer_cutlass_moe.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 0d9fc4789b96..b91158b62970 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -158,29 +158,26 @@ def apply( f"{activation=} missing from {activation_str_to_value_map.keys()=}" ) - # Select quantization metadata based on FP8 format/path + # Select quantization arguments based on FP8 format/path if ( self.quant_dtype == torch.float8_e4m3fn and not self.use_deepseek_fp8_block_scale ): - # FP8 per-tensor path: use global alphas/scales; do not pass input_sf + # FP8 PER TENSOR: + # * scales are computed in process_weights as per below + # * input is quantized quant_scales = [ - self.w1_scale, # done: g1_alphas - self.w2_scale, # done: a2_gscale - self.a2_scale, # done: g2_alphas - self.a1_scale, # done: a1_gscale + self.w1_scale, # w1_scale = w1_scale * w13_input_scale + self.a2_scale, # a2_scale = 1.0 / a2_input_scale + self.w2_scale, # w2_scale = w2_scale * w2_input_scale + self.a1_scale, # a1_scale = w1_input_scale ] - for i, quant_scale in enumerate(quant_scales): - logger.info_once(f"{i}, {quant_scale.shape=}") - a1q_scale = None # not passing input_sf in fp8 fc1_expert_weights = w1 fc2_expert_weights = w2 elif self.quant_dtype == "nvfp4": - # Ensure w1_scale and w2_scale are not None before calling view - assert self.w1_scale is not None and self.w2_scale is not None, ( - "w1_scale and w2_scale must not be None for FlashInferExperts" - ) + # NVFP4 + assert self.w1_scale is not None and self.w2_scale is not None # 
Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. quant_scales = [ @@ -195,7 +192,9 @@ def apply( fc1_expert_weights = w1.view(torch.long) fc2_expert_weights = w2.view(torch.long) elif self.use_deepseek_fp8_block_scale: - # FP8 block-scale path: provide block-scale weights, omit a1q_scale + # FP8 BLOCK: + # * dynamic input quantization + # * calculation of input scales is fused in the kernel. quant_scales = [ self.w1_scale, self.w2_scale, From eb6699b09f3f8afb86078d713025fca6bed61910 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:18:49 +0000 Subject: [PATCH 12/55] comment nits Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index b91158b62970..a71213a495b3 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -165,7 +165,7 @@ def apply( ): # FP8 PER TENSOR: # * scales are computed in process_weights as per below - # * input is quantized + # * hidden states are quantized quant_scales = [ self.w1_scale, # w1_scale = w1_scale * w13_input_scale self.a2_scale, # a2_scale = 1.0 / a2_input_scale From 5be7ab1e85619413984e9a698630e6d662fbd0b3 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:20:08 +0000 Subject: [PATCH 13/55] comment nits Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index a71213a495b3..2420b249aab4 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -193,8 +193,8 @@ def apply( fc2_expert_weights = w2.view(torch.long) elif self.use_deepseek_fp8_block_scale: # FP8 BLOCK: - # * dynamic input quantization - # * calculation of input scales is fused in the kernel. + # * hidden_states are bf16 + # * dyanmic input quant fused into kernels quant_scales = [ self.w1_scale, self.w2_scale, From 844a65a772fcbab861330c61bcb51e6e8a8e3eb0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:28:59 +0000 Subject: [PATCH 14/55] remove Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/modelopt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 6938871efc0b..1959ebaee3c2 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -946,7 +946,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) - # register_moe_scaling_factors(layer) w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( _convert_moe_scaling_factors( w13_scale=layer.w13_weight_scale, From 24a0302148d5694740605e7286a9d3dade064b47 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 16:43:03 +0000 Subject: [PATCH 15/55] rename method Signed-off-by: Robert Shaw --- .../layers/quantization/modelopt.py | 7 +++-- .../quantization/utils/flashinfer_utils.py | 28 ++++++++----------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 1959ebaee3c2..1d335bb6711b 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py 
+++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -45,10 +45,9 @@ ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, - # register_moe_scaling_factors, - _convert_moe_scaling_factors, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, + convert_flashinfer_fp8_moe_scales, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, @@ -946,8 +945,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + + # NOTE(rob): we should not do this if we are using fp8 block. w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - _convert_moe_scaling_factors( + convert_flashinfer_fp8_moe_scales( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index bd3dd7971cbd..9e2b29ea66c2 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -155,6 +155,18 @@ def apply_flashinfer_per_tensor_scale_fp8( ) +def convert_flashinfer_fp8_moe_scales( + w13_scale: torch.Tensor, + w13_input_scale: torch.Tensor, + w2_scale: torch.Tensor, + w2_input_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + w13_scale = (w13_scale * w13_input_scale).squeeze() + w2_scale = (w2_scale * w2_input_scale).squeeze() + w2_input_scale = 1.0 / w2_input_scale + return w13_scale, w13_input_scale, w2_scale, w2_input_scale + + def get_moe_scaling_factors( input_scale: torch.Tensor, gemm1_weights_scale: torch.Tensor, @@ 
-168,22 +180,6 @@ def get_moe_scaling_factors( return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar -def _convert_moe_scaling_factors( - w13_scale: torch.Tensor, - w13_input_scale: torch.Tensor, - w2_scale: torch.Tensor, - w2_input_scale: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - w13_scale = ( - w13_scale * w13_input_scale - ).squeeze() # output1_scales_gate_scalar, g1_alphas - w2_scale = ( - w2_scale * w2_input_scale - ).squeeze() # output2_scales_scalar, g2_alphas - w2_input_scale = 1.0 / w2_input_scale # w2_input_scale_inv, a2_gscale - return w13_scale, w13_input_scale, w2_scale, w2_input_scale - - def register_moe_scaling_factors(layer: torch.nn.Module) -> None: output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors( layer.w13_input_scale, From 7edf70fc2c9195046fae448e6f52795b66e2006e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 17:21:24 +0000 Subject: [PATCH 16/55] stash trtllm fix Signed-off-by: Robert Shaw --- .../layers/quantization/utils/flashinfer_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 9e2b29ea66c2..357bbbf15678 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -140,9 +140,12 @@ def apply_flashinfer_per_tensor_scale_fp8( input_scale=layer.w13_input_scale, gemm1_weights=layer.w13_weight, gemm2_weights=layer.w2_weight, - output1_scales_scalar=layer.output1_scales_scalar, - output1_scales_gate_scalar=layer.output1_scales_gate_scalar, - output2_scales_scalar=layer.output2_scales_scalar, + # output1_scales_scalar=layer.output1_scales_scalar, + output1_scales_scalar=(layer.w2_input_scale * layer.w13_scale), + # 
output1_scales_gate_scalar=layer.output1_scales_gate_scalar, + output1_scales_gate_scalar=layer.w13_scale, + # output2_scales_scalar=layer.output2_scales_scalar, + output2_scales_scalar=layer.w2_scale, num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, @@ -164,6 +167,7 @@ def convert_flashinfer_fp8_moe_scales( w13_scale = (w13_scale * w13_input_scale).squeeze() w2_scale = (w2_scale * w2_input_scale).squeeze() w2_input_scale = 1.0 / w2_input_scale + return w13_scale, w13_input_scale, w2_scale, w2_input_scale From 9d994a68e0e2aae80a6b50675ea72a2ef62490c0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 12:31:25 -0500 Subject: [PATCH 17/55] stash changes Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/modelopt.py | 4 ++-- .../layers/quantization/utils/flashinfer_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 1d335bb6711b..3b6e39ac329d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -47,7 +47,7 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - convert_flashinfer_fp8_moe_scales, + convert_flashinfer_fp8_moe_per_tensor_scales, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, @@ -948,7 +948,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # NOTE(rob): we should not do this if we are using fp8 block. 
w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_flashinfer_fp8_moe_scales( + convert_flashinfer_fp8_moe_per_tensor_scales( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 9e2b29ea66c2..739fd96cf45f 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -155,7 +155,7 @@ def apply_flashinfer_per_tensor_scale_fp8( ) -def convert_flashinfer_fp8_moe_scales( +def convert_flashinfer_fp8_moe_per_tensor_scales( w13_scale: torch.Tensor, w13_input_scale: torch.Tensor, w2_scale: torch.Tensor, From c9a7e5bb08715b7b269245da248a743af4ec7d37 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:34:06 -0500 Subject: [PATCH 18/55] updated Signed-off-by: Robert Shaw --- .../quantization/utils/flashinfer_utils.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index ac9e07125a3d..b959d3a8e4a8 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -118,16 +118,15 @@ def apply_flashinfer_per_tensor_scale_fp8( import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_scalar to be initialized" - ) - assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_gate_scalar to be initialized" - ) - assert layer.output1_scales_scalar is not None, ( - "Expected output2_scales_scalar to be initialized" - ) - + # assert layer.output1_scales_scalar is not 
None, ( + # "Expected output1_scales_scalar to be initialized" + # ) + # assert layer.output1_scales_scalar is not None, ( + # "Expected output1_scales_gate_scalar to be initialized" + # ) + # assert layer.output1_scales_scalar is not None, ( + # "Expected output2_scales_scalar to be initialized" + # ) from vllm.model_executor.models.llama4 import Llama4MoE assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( @@ -141,11 +140,11 @@ def apply_flashinfer_per_tensor_scale_fp8( gemm1_weights=layer.w13_weight, gemm2_weights=layer.w2_weight, # output1_scales_scalar=layer.output1_scales_scalar, - output1_scales_scalar=(layer.w2_input_scale * layer.w13_scale), + output1_scales_scalar=(layer.w2_input_scale * layer.w13_weight_scale), # output1_scales_gate_scalar=layer.output1_scales_gate_scalar, - output1_scales_gate_scalar=layer.w13_scale, + output1_scales_gate_scalar=layer.w13_weight_scale, # output2_scales_scalar=layer.output2_scales_scalar, - output2_scales_scalar=layer.w2_scale, + output2_scales_scalar=layer.w2_weight_scale, num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, From 2408ad2def3c761f4855524f7926fa8a836a754a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:37:43 -0500 Subject: [PATCH 19/55] make trtllm work Signed-off-by: Robert Shaw --- .../layers/quantization/utils/flashinfer_utils.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index b959d3a8e4a8..19a99188a430 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -116,17 +116,6 @@ def apply_flashinfer_per_tensor_scale_fp8( ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - - # assert 
layer.output1_scales_scalar is not None, ( - # "Expected output1_scales_scalar to be initialized" - # ) - # assert layer.output1_scales_scalar is not None, ( - # "Expected output1_scales_gate_scalar to be initialized" - # ) - # assert layer.output1_scales_scalar is not None, ( - # "Expected output2_scales_scalar to be initialized" - # ) from vllm.model_executor.models.llama4 import Llama4MoE assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( "FusedMoE flashinfer kernels are only supported for Llama4" ) From 96ff599c6be15961b930c3bdca84b17d9da35994 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:39:36 -0500 Subject: [PATCH 20/55] add back import Signed-off-by: Robert Shaw --- .../layers/quantization/utils/flashinfer_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 19a99188a430..b946fea91fb0 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -116,11 +116,13 @@ def apply_flashinfer_per_tensor_scale_fp8( ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401 from vllm.model_executor.models.llama4 import Llama4MoE assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( "FusedMoE flashinfer kernels are only supported for Llama4" ) + return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( routing_logits=router_logits, routing_bias=routing_bias, @@ -128,11 +130,8 @@ def apply_flashinfer_per_tensor_scale_fp8( input_scale=layer.w13_input_scale, gemm1_weights=layer.w13_weight, gemm2_weights=layer.w2_weight, - # output1_scales_scalar=layer.output1_scales_scalar, output1_scales_scalar=(layer.w2_input_scale * layer.w13_weight_scale), - # output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
output1_scales_gate_scalar=layer.w13_weight_scale, - # output2_scales_scalar=layer.output2_scales_scalar, output2_scales_scalar=layer.w2_weight_scale, num_experts=global_num_experts, top_k=top_k, From f9a4724a8ff7e9bc41033c1a4c4348223b80fd10 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:53:54 -0500 Subject: [PATCH 21/55] update comments Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 16 +++++- .../quantization/utils/flashinfer_utils.py | 57 +++++++------------ 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index bf4ef2d30466..a7de62f40053 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -14,8 +14,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, + convert_flashinfer_fp8_moe_per_tensor_scales, flashinfer_cutlass_moe_fp8, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31, ) @@ -122,7 +122,19 @@ def make_moe_tensors_8bit( all2all_backend="naive", ) - register_moe_scaling_factors(layer) + # Convert scales to flashinfer format. 
+ w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( + convert_flashinfer_fp8_moe_per_tensor_scales( + w13_scale=layer.w13_weight_scale, + w13_input_scale=layer.w13_input_scale, + w2_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, + ) + ) + layer.w13_weight_scale.data = w13_weight_scale + layer.w13_input_scale.data = w13_input_scale + layer.w2_weight_scale.data = w2_weight_scale + layer.w2_input_scale.data = w2_input_scale # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index b946fea91fb0..899202e97bf2 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -151,6 +151,27 @@ def convert_flashinfer_fp8_moe_per_tensor_scales( w2_scale: torch.Tensor, w2_input_scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + On disk, we have an explicit scale for each weight / input: + * w13_scale + * w13_input_scale + * w2_scale + * w2_input_scale + + Flashinfer kernels instead expect: + * output1_scales_gate_scalar = w13_scale * w13_input_scale + * output2_scales_scalar = w2_scale * w2_input_scale + * w2_input_scale_inv = 1.0 / w2_input_scale + + So, we put: + * w13_scale = output1_scales_gate_scalar + * w2_scale = output2_scales_scalar + * w2_input_scale = w2_input_scale_inv + + NOTE(rob): we should consider updating moe_quant_config to make + this more explicit in the future rather than smuggling it + through the existing names. 
+ """ w13_scale = (w13_scale * w13_input_scale).squeeze() w2_scale = (w2_scale * w2_input_scale).squeeze() w2_input_scale = 1.0 / w2_input_scale @@ -158,42 +179,6 @@ def convert_flashinfer_fp8_moe_per_tensor_scales( return w13_scale, w13_input_scale, w2_scale, w2_input_scale -def get_moe_scaling_factors( - input_scale: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - activation_scale: torch.Tensor, - gemm2_weights_scale: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale) - output1_scales_gate_scalar = gemm1_weights_scale * input_scale - output2_scales_scalar = activation_scale * gemm2_weights_scale - - return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar - - -def register_moe_scaling_factors(layer: torch.nn.Module) -> None: - output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors( - layer.w13_input_scale, - layer.w13_weight_scale, - layer.w2_input_scale, - layer.w2_weight_scale, - ) - layer.register_parameter( - "output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False) - ) - layer.register_parameter( - "output1_scales_gate_scalar", - torch.nn.Parameter(output1_gate_scales, requires_grad=False), - ) - layer.register_parameter( - "output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False) - ) - layer.register_parameter( - "w2_input_scale_inv", - torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False), - ) - - def build_flashinfer_fp8_cutlass_moe_prepare_finalize( moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False ) -> mk.FusedMoEPrepareAndFinalize: From e8831f93dee08ed24be501d15753cf53c9c74d09 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 15:59:02 -0500 Subject: [PATCH 22/55] apply changes to fp8.py Signed-off-by: Robert Shaw --- .../model_executor/layers/quantization/fp8.py | 32 +++++++++++++++---- 1 file changed, 26 
insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 08e1f4d444ee..16c8097edccc 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -49,8 +49,8 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, + convert_flashinfer_fp8_moe_per_tensor_scales, get_flashinfer_moe_backend, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -893,6 +893,8 @@ def _convert_weights_to_kernel_format( w2_weight: torch.Tensor, w13_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, + w13_input_scale: torch.Tensor | None, + w2_input_scale: torch.Tensor | None, ) -> None: if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: assert self.block_quant @@ -937,9 +939,15 @@ def _convert_weights_to_kernel_format( if self.block_quant: w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: - # TODO(rob): this function is a hack that renames the scaling - # factors in the Module. This is a hack we should clean up. - register_moe_scaling_factors(layer) + w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( + convert_flashinfer_fp8_moe_per_tensor_scales( + w13_scale=w13_weight_scale, + w13_input_scale=w13_input_scale, + w2_scale=w2_weight_scale, + w2_input_scale=w2_input_scale, + ) + ) + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) elif self.fp8_backend == Fp8MoeBackend.AITER: @@ -1093,7 +1101,13 @@ def process_weights_after_loading(self, layer: Module) -> None: # Shuffle weights into the runtime format. 
self._convert_weights_to_kernel_format( - layer, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale + layer=layer, + w13_weight=w13_weight, + w2_weight=w2_weight, + w13_weight_scale=w13_weight_scale, + w2_weight_scale=w2_weight_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, ) # Setup modular kernel for TP case. @@ -1427,7 +1441,13 @@ def process_weights_after_loading(self, layer: Module) -> None: # Shuffle weights into the runtime format. self._convert_weights_to_kernel_format( - layer, w13_weight, w2_weight, layer.w13_weight_scale, layer.w2_weight_scale + layer=layer, + w13_weight=w13_weight, + w2_weight=w2_weight, + w13_weight_scale=layer.w13_weight_scale, + w2_weight_scale=layer.w2_weight_scale, + w13_input_scale=None, + w2_input_scale=None, ) # Setup modular kernel for TP case. From f8f9a3305e6640cf19ad8d5db2a10acc3854e35d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 16:00:50 -0500 Subject: [PATCH 23/55] nit Signed-off-by: Robert Shaw --- .../model_executor/layers/quantization/utils/flashinfer_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 899202e97bf2..96f7f901f980 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -122,7 +122,6 @@ def apply_flashinfer_per_tensor_scale_fp8( assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( "FusedMoE flashinfer kernels are only supported for Llama4" ) - return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( routing_logits=router_logits, routing_bias=routing_bias, From 59f97a6170ac761446190b4147302387a6bdb92a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 16:01:48 -0500 Subject: [PATCH 24/55] revert unneeded assets Signed-off-by: Robert Shaw --- 
vllm/model_executor/layers/quantization/fp8.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 16c8097edccc..c332ff1ecb55 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -745,18 +745,6 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): "FlashInfer CUTLASS FP8 MoE backend only supports block " "size [128, 128]." ) - if not self.block_quant: - if layer.renormalize or layer.custom_routing_function is not None: - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend does custom routing " - f"function or renormalization, but got {layer.renormalize} and " - f"{layer.custom_routing_function}." - ) - if layer.scoring_func != "sigmoid": - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports " - f"'sigmoid' scoring function, but got {layer.scoring_func}." - ) if layer.activation != "silu": raise NotImplementedError( "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " From 113e472c9bd8c920638401663b9182d312281f59 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 21:06:10 +0000 Subject: [PATCH 25/55] rename Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 4 ++-- vllm/model_executor/layers/quantization/fp8.py | 4 ++-- vllm/model_executor/layers/quantization/modelopt.py | 4 ++-- .../layers/quantization/utils/flashinfer_utils.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index a7de62f40053..7250ca7b2fe9 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, - 
convert_flashinfer_fp8_moe_per_tensor_scales, + convert_flashinfer_fp8_moe_per_tensor_scales_for_fi, flashinfer_cutlass_moe_fp8, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31, @@ -124,7 +124,7 @@ def make_moe_tensors_8bit( # Convert scales to flashinfer format. w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_flashinfer_fp8_moe_per_tensor_scales( + convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c332ff1ecb55..c4f4163a1f95 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -49,7 +49,7 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - convert_flashinfer_fp8_moe_per_tensor_scales, + convert_flashinfer_fp8_moe_per_tensor_scales_for_fi, get_flashinfer_moe_backend, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, @@ -928,7 +928,7 @@ def _convert_weights_to_kernel_format( w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_flashinfer_fp8_moe_per_tensor_scales( + convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( w13_scale=w13_weight_scale, w13_input_scale=w13_input_scale, w2_scale=w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 3b6e39ac329d..ae89d1d475fa 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -47,7 +47,7 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - convert_flashinfer_fp8_moe_per_tensor_scales, + 
convert_flashinfer_fp8_moe_per_tensor_scales_for_fi, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, @@ -948,7 +948,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # NOTE(rob): we should not do this if we are using fp8 block. w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_flashinfer_fp8_moe_per_tensor_scales( + convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 96f7f901f980..e52d8f88ed65 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -144,7 +144,7 @@ def apply_flashinfer_per_tensor_scale_fp8( ) -def convert_flashinfer_fp8_moe_per_tensor_scales( +def convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( w13_scale: torch.Tensor, w13_input_scale: torch.Tensor, w2_scale: torch.Tensor, From a98a38016a7acaf32fd0a72f8ffbbad7d5ea9b15 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 16:10:30 -0500 Subject: [PATCH 26/55] update comment Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_cutlass_prepare_finalize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 1a4f6598c296..59e902272d57 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -188,7 +188,6 @@ def prepare( if not self.use_deepseek_fp8_block_scale: a1q, a1q_scale = moe_kernel_quantize_input( a1, - # TODO(rob): look at the signature of the nvfp4 
case. quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, From df5035c160242826a51348c77ac420caa2099af7 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 16:11:36 -0500 Subject: [PATCH 27/55] naming Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 4 ++-- vllm/model_executor/layers/quantization/fp8.py | 4 ++-- vllm/model_executor/layers/quantization/modelopt.py | 4 ++-- .../layers/quantization/utils/flashinfer_utils.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 7250ca7b2fe9..26688086d3af 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, - convert_flashinfer_fp8_moe_per_tensor_scales_for_fi, + convert_fp8_moe_per_tensor_scales_for_fi, flashinfer_cutlass_moe_fp8, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31, @@ -124,7 +124,7 @@ def make_moe_tensors_8bit( # Convert scales to flashinfer format. 
w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( + convert_fp8_moe_per_tensor_scales_for_fi( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c4f4163a1f95..efc3859a0bac 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -49,7 +49,7 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - convert_flashinfer_fp8_moe_per_tensor_scales_for_fi, + convert_fp8_moe_per_tensor_scales_for_fi, get_flashinfer_moe_backend, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, @@ -928,7 +928,7 @@ def _convert_weights_to_kernel_format( w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( + convert_fp8_moe_per_tensor_scales_for_fi( w13_scale=w13_weight_scale, w13_input_scale=w13_input_scale, w2_scale=w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index ae89d1d475fa..84610059952f 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -47,7 +47,7 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - convert_flashinfer_fp8_moe_per_tensor_scales_for_fi, + convert_fp8_moe_per_tensor_scales_for_fi, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, @@ -948,7 +948,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # NOTE(rob): we should not do this if we are using fp8 block. 
w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( + convert_fp8_moe_per_tensor_scales_for_fi( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index e52d8f88ed65..a29ec13ce12a 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -144,7 +144,7 @@ def apply_flashinfer_per_tensor_scale_fp8( ) -def convert_flashinfer_fp8_moe_per_tensor_scales_for_fi( +def convert_fp8_moe_per_tensor_scales_for_fi( w13_scale: torch.Tensor, w13_input_scale: torch.Tensor, w2_scale: torch.Tensor, From 36784020424a0cebec0c04a7c7d816ef682519f6 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 1 Jan 2026 21:16:56 +0000 Subject: [PATCH 28/55] add back check to prevent mixtral Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index efc3859a0bac..55b52e73cf3d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -745,6 +745,18 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): "FlashInfer CUTLASS FP8 MoE backend only supports block " "size [128, 128]." ) + if not self.block_quant: + if layer.renormalize or layer.custom_routing_function is not None: + raise NotImplementedError( + "FlashInfer CUTLASS FP8 MoE backend does custom routing " + f"function or renormalization, but got {layer.renormalize} and " + f"{layer.custom_routing_function}." 
+ ) + if layer.scoring_func != "sigmoid": + raise NotImplementedError( + "FlashInfer CUTLASS FP8 MoE backend only supports " + f"'sigmoid' scoring function, but got {layer.scoring_func}." + ) if layer.activation != "silu": raise NotImplementedError( "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " From c30d40492d86f150140f7dcd7d9a1648b2a5350b Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 10:27:00 -0500 Subject: [PATCH 29/55] remove delete me Signed-off-by: Robert Shaw --- vllm/model_executor/models/llama4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index b9638bb12948..7b3da3e10ab8 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,7 +432,6 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. expert_param_loaded = False - loaded_weight = loaded_weight.to("cuda") # DELETE ME!!!! # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last From a285f5e0a182a3bd26d359b53283168e3a81068f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:28:27 +0000 Subject: [PATCH 30/55] update to address pavani's feedback Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 4 +- .../fused_moe/flashinfer_cutlass_moe.py | 18 +++--- .../model_executor/layers/quantization/fp8.py | 47 ++++++++++++--- .../layers/quantization/modelopt.py | 60 +++++++++++++------ .../quantization/utils/flashinfer_utils.py | 46 +++++--------- 5 files changed, 104 insertions(+), 71 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 26688086d3af..e2518ed069f0 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -14,8 +14,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from 
vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, - convert_fp8_moe_per_tensor_scales_for_fi, flashinfer_cutlass_moe_fp8, + make_fp8_moe_alpha_scales_for_fi, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31, ) @@ -124,7 +124,7 @@ def make_moe_tensors_8bit( # Convert scales to flashinfer format. w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_fp8_moe_per_tensor_scales_for_fi( + make_fp8_moe_alpha_scales_for_fi( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 2420b249aab4..99f435e44f32 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -158,19 +158,17 @@ def apply( f"{activation=} missing from {activation_str_to_value_map.keys()=}" ) - # Select quantization arguments based on FP8 format/path + # Select quantization metadata based on FP8 format/path if ( self.quant_dtype == torch.float8_e4m3fn and not self.use_deepseek_fp8_block_scale ): - # FP8 PER TENSOR: - # * scales are computed in process_weights as per below - # * hidden states are quantized + # FP8 PER TENSOR quant_scales = [ - self.w1_scale, # w1_scale = w1_scale * w13_input_scale - self.a2_scale, # a2_scale = 1.0 / a2_input_scale - self.w2_scale, # w2_scale = w2_scale * w2_input_scale - self.a1_scale, # a1_scale = w1_input_scale + self.a1_gscale, # w13_weight_scale * w13_input_scale + self.a2_scale, + self.a2_gscale, # w2_weight_scale * w2_input_scale + self.a1_scale, ] a1q_scale = None # not passing input_sf in fp8 fc1_expert_weights = w1 @@ -192,9 +190,7 @@ def apply( fc1_expert_weights = w1.view(torch.long) fc2_expert_weights = w2.view(torch.long) elif self.use_deepseek_fp8_block_scale: - # FP8 BLOCK: - # * 
hidden_states are bf16 - # * dyanmic input quant fused into kernels + # FP8 BLOCK: input quantization is fused into the kernel quant_scales = [ self.w1_scale, self.w2_scale, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 55b52e73cf3d..4af538a4f426 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -49,8 +49,8 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - convert_fp8_moe_per_tensor_scales_for_fi, get_flashinfer_moe_backend, + make_fp8_moe_alpha_scales_for_fi, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -734,6 +734,20 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled ) + is_flashinfer = self.fp8_backend in [ + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + ] + if ( + is_flashinfer + and not self.block_quant + and self.quant_config.activation_scheme != "static" + ): + raise NotImplementedError( + "FlashInfer FP8 MoE backend only dynamic block quantization" + "or static per tensor activation quantization." 
+ ) + self.marlin_input_dtype = None self.flashinfer_moe_backend: FlashinferMoeBackend | None = None if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: @@ -939,17 +953,25 @@ def _convert_weights_to_kernel_format( if self.block_quant: w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: - w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_fp8_moe_per_tensor_scales_for_fi( - w13_scale=w13_weight_scale, - w13_input_scale=w13_input_scale, - w2_scale=w2_weight_scale, - w2_input_scale=w2_input_scale, - ) + assert self.quant_config.activation_scheme == "static" + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + w13_scale=w13_weight_scale, + w13_input_scale=w13_input_scale, + w2_scale=w2_weight_scale, + w2_input_scale=w2_input_scale, ) + layer.w2_input_scale = 1.0 / layer.w2_input_scale if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) + layer.output1_scales_gate_scalar = g1_alphas + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale + layer.output2_scales_scalar = g2_alphas + else: + assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS + layer.g1_alphas = g1_alphas + layer.g2_alphas = g2_alphas + elif self.fp8_backend == Fp8MoeBackend.AITER: w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( w13_weight, w2_weight @@ -1209,6 +1231,11 @@ def select_gemm_impl( def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: + # TRTLLM does not use Modular Kernel. + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + return None + + # MARLIN uses mixed precision W8A16 config. if self.fp8_backend == Fp8MoeBackend.MARLIN: return fp8_w8a16_moe_quant_config( w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), @@ -1216,12 +1243,16 @@ def get_fused_moe_quant_config( block_shape=self.weight_block_size, ) + # All other backends use W8A8 config. 
return fp8_w8a8_moe_quant_config( w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.weight_block_size, + # FI CUTLASS uses alphas = weight_s * input_s + g1_alphas=getattr(layer, "g1_alphas", None), + g2_alphas=getattr(layer, "g2_alphas", None), ) @property diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 84610059952f..346235144645 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -47,10 +47,10 @@ FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - convert_fp8_moe_per_tensor_scales_for_fi, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, + make_fp8_moe_alpha_scales_for_fi, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -943,23 +943,33 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.flashinfer_moe_backend is not None: if self.moe.is_act_and_mul: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) - # NOTE(rob): we should not do this if we are using fp8 block. 
- w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - convert_fp8_moe_per_tensor_scales_for_fi( + # g1_alphas = w13_weight_scale * w13_input_scale + # g2_alphas = w2_weight_scale * w2_input_scale + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( w13_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, w2_scale=layer.w2_weight_scale, w2_input_scale=layer.w2_input_scale, ) - ) - layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False) - layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False) - layer.w13_input_scale = Parameter(w13_input_scale, requires_grad=False) - layer.w2_input_scale = Parameter(w2_input_scale, requires_grad=False) + # Flashinfer uses the inverse input scale for the activation. + layer.w2_input_scale = 1.0 / layer.w2_input_scale + + # NOTE(rob): these scales do not need to be registered as parameters + # since they are not used for weight (re)-loading. + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + layer.output1_scales_gate_scalar = g1_alphas + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale + layer.output2_scales_scalar = g2_alphas + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + layer.g1_alphas = g1_alphas + layer.g2_alphas = g2_alphas + else: + raise ValueError( + f"Unknown FlashInfer MoE backend: {self.flashinfer_moe_backend}" + ) def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. 
@@ -1009,15 +1019,27 @@ def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + # TRTLLM does not use modular kernels return None - - return fp8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - per_act_token_quant=False, - ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + # CUTLASS uses a single dequant alpha = weight_scale * input_scale + assert hasattr(layer, "g1_alphas") and hasattr(layer, "g2_alphas") + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + ) + else: + assert self.flashinfer_moe_backend is None + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) def apply( self, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index a29ec13ce12a..efdea432d620 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -119,6 +119,12 @@ def apply_flashinfer_per_tensor_scale_fp8( import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 from vllm.model_executor.models.llama4 import Llama4MoE + assert ( + hasattr(layer, "output1_scales_scalar") + and hasattr(layer, "output1_scales_gate_scalar") + and hasattr(layer, "output2_scales_scalar") + ) + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( "FusedMoE flashinfer kernels are only supported for Llama4" ) @@ -129,9 +135,9 @@ def 
apply_flashinfer_per_tensor_scale_fp8( input_scale=layer.w13_input_scale, gemm1_weights=layer.w13_weight, gemm2_weights=layer.w2_weight, - output1_scales_scalar=(layer.w2_input_scale * layer.w13_weight_scale), - output1_scales_gate_scalar=layer.w13_weight_scale, - output2_scales_scalar=layer.w2_weight_scale, + output1_scales_scalar=layer.output1_scales_scalar, + output1_scales_gate_scalar=layer.output1_scales_gate_scalar, + output2_scales_scalar=layer.output2_scales_scalar, num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, @@ -144,38 +150,16 @@ def apply_flashinfer_per_tensor_scale_fp8( ) -def convert_fp8_moe_per_tensor_scales_for_fi( +def make_fp8_moe_alpha_scales_for_fi( w13_scale: torch.Tensor, w13_input_scale: torch.Tensor, w2_scale: torch.Tensor, w2_input_scale: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - On disk, we have an explicit scale for each weight / input: - * w13_scale - * w13_input_scale - * w2_scale - * w2_input_scale - - Flashinfer kernels instead expect: - * output1_scales_gate_scalar = w13_scale * w13_input_scale - * output2_scales_scalar = w2_scale * w2_input_scale - * w2_input_scale_inv = 1.0 / w2_input_scale - - So, we put: - * w13_scale = output1_scales_gate_scalar - * w2_scale = output2_scales_scalar - * w2_input_scale = w2_input_scale_inv - - NOTE(rob): we should consider updating moe_quant_config to make - this more explicit in the future rather than smuggling it - through the existing names. 
- """ - w13_scale = (w13_scale * w13_input_scale).squeeze() - w2_scale = (w2_scale * w2_input_scale).squeeze() - w2_input_scale = 1.0 / w2_input_scale - - return w13_scale, w13_input_scale, w2_scale, w2_input_scale +) -> tuple[torch.Tensor, torch.Tensor]: + g1_alphas = (w13_scale * w13_input_scale).squeeze() + g2_alphas = (w2_scale * w2_input_scale).squeeze() + + return g1_alphas, g2_alphas def build_flashinfer_fp8_cutlass_moe_prepare_finalize( From 2d96161fdd88f7f43ce53e575700541a2696836e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:30:07 +0000 Subject: [PATCH 31/55] reduce LOC change Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_cutlass_moe.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 99f435e44f32..6f983c33ed00 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -163,7 +163,7 @@ def apply( self.quant_dtype == torch.float8_e4m3fn and not self.use_deepseek_fp8_block_scale ): - # FP8 PER TENSOR + # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ self.a1_gscale, # w13_weight_scale * w13_input_scale self.a2_scale, @@ -174,8 +174,10 @@ def apply( fc1_expert_weights = w1 fc2_expert_weights = w2 elif self.quant_dtype == "nvfp4": - # NVFP4 - assert self.w1_scale is not None and self.w2_scale is not None + # Ensure w1_scale and w2_scale are not None before calling view + assert self.w1_scale is not None and self.w2_scale is not None, ( + "w1_scale and w2_scale must not be None for FlashInferExperts" + ) # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. 
quant_scales = [ @@ -190,7 +192,7 @@ def apply( fc1_expert_weights = w1.view(torch.long) fc2_expert_weights = w2.view(torch.long) elif self.use_deepseek_fp8_block_scale: - # FP8 BLOCK: input quantization is fused into the kernel + # FP8 block-scale path: provide block-scale weights, omit a1q_scale quant_scales = [ self.w1_scale, self.w2_scale, From a9108728323d3c38ace38a3ef3160cc6aa8da42c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:31:53 +0000 Subject: [PATCH 32/55] fix Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 6f983c33ed00..00f980d19fc0 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -165,9 +165,9 @@ def apply( ): # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ - self.a1_gscale, # w13_weight_scale * w13_input_scale + self.g1_alphas, # w13_weight_scale * w13_input_scale self.a2_scale, - self.a2_gscale, # w2_weight_scale * w2_input_scale + self.g2_alphas, # w2_weight_scale * w2_input_scale self.a1_scale, ] a1q_scale = None # not passing input_sf in fp8 From d2decd6d139e2a35b1c596d74b01f13600907037 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:32:18 +0000 Subject: [PATCH 33/55] reduce loc nit Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 00f980d19fc0..7cfae4ff66c3 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py 
@@ -170,6 +170,7 @@ def apply( self.g2_alphas, # w2_weight_scale * w2_input_scale self.a1_scale, ] + a1q_scale = None # not passing input_sf in fp8 fc1_expert_weights = w1 fc2_expert_weights = w2 From 870fc6aba325e2910c58e1457d1028a3d7a337c4 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:35:34 +0000 Subject: [PATCH 34/55] clean up fi checking Signed-off-by: Robert Shaw --- .../model_executor/layers/quantization/fp8.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 4af538a4f426..aa6a93374245 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -734,20 +734,6 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled ) - is_flashinfer = self.fp8_backend in [ - Fp8MoeBackend.FLASHINFER_TRTLLM, - Fp8MoeBackend.FLASHINFER_CUTLASS, - ] - if ( - is_flashinfer - and not self.block_quant - and self.quant_config.activation_scheme != "static" - ): - raise NotImplementedError( - "FlashInfer FP8 MoE backend only dynamic block quantization" - "or static per tensor activation quantization." - ) - self.marlin_input_dtype = None self.flashinfer_moe_backend: FlashinferMoeBackend | None = None if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: @@ -776,6 +762,14 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " "activation function, but got {layer.activation}." ) + dynamic_per_token = ( + not self.block_quant and self.quant_config.activation_scheme != "static" + ) + if self.flashinfer_moe_backend is not None and dynamic_per_token: + raise NotImplementedError( + "FlashInfer FP8 MoE backend does not support dynamic per token " + "activation quantization." 
+ ) def create_weights( self, From 23e79fd8fec73aca3986941ee96ed4b91c267944 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:39:40 +0000 Subject: [PATCH 35/55] updated Signed-off-by: Robert Shaw --- vllm/model_executor/models/llama4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 7b3da3e10ab8..6929234612c2 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,6 +432,7 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. expert_param_loaded = False + loaded_weight = loaded_weight.to("cuda") # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last From 86a0e5c2638c866069dfddb86c459e0370c51615 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:49:15 +0000 Subject: [PATCH 36/55] fix assign float to parameter Signed-off-by: Robert Shaw --- .../model_executor/layers/quantization/fp8.py | 27 ++++++++++++++----- .../layers/quantization/modelopt.py | 11 ++++---- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index aa6a93374245..8bf238aac1e1 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -954,12 +954,12 @@ def _convert_weights_to_kernel_format( w2_scale=w2_weight_scale, w2_input_scale=w2_input_scale, ) - layer.w2_input_scale = 1.0 / layer.w2_input_scale + layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) layer.output1_scales_gate_scalar = g1_alphas - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar 
= g2_alphas else: assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS @@ -1238,15 +1238,30 @@ def get_fused_moe_quant_config( ) # All other backends use W8A8 config. + g1_alphas = None + g2_alphas = None + a2_scale = layer.w13_input_scale + if ( + self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS + and not self.block_quant + ): + # CUTLASS uses alpha = weight_s * input_s and a2_scale_inv + assert ( + hasattr(layer, "w2_input_scale_inv") + and hasattr(layer, "g1_alphas") + and hasattr(layer, "g2_alphas") + ) + g1_alphas = layer.g1_alphas + g2_alphas = layer.g2_alphas + return fp8_w8a8_moe_quant_config( w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + a2_scale=a2_scale, block_shape=self.weight_block_size, - # FI CUTLASS uses alphas = weight_s * input_s - g1_alphas=getattr(layer, "g1_alphas", None), - g2_alphas=getattr(layer, "g2_alphas", None), + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, ) @property diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 346235144645..c74314fa0b4f 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -954,15 +954,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Flashinfer uses the inverse input scale for the activation. - layer.w2_input_scale = 1.0 / layer.w2_input_scale + layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale # NOTE(rob): these scales do not need to be registered as parameters # since they are not used for weight (re)-loading. 
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) layer.output1_scales_gate_scalar = g1_alphas - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: layer.g1_alphas = g1_alphas layer.g2_alphas = g2_alphas @@ -1022,13 +1022,12 @@ def get_fused_moe_quant_config( # TRTLLM does not use modular kernels return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - # CUTLASS uses a single dequant alpha = weight_scale * input_scale - assert hasattr(layer, "g1_alphas") and hasattr(layer, "g2_alphas") + # CUTLASS uses alpha = weight_s * input_s and a2_scale_inv return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, + a2_scale=layer.w2_input_scale_inv, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, ) From 7eaa18bea8d153ed32a43e0ac2732f616377fe49 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 19:51:08 +0000 Subject: [PATCH 37/55] updated doc string Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 4266514bc94e..a29e7cf20db4 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -452,9 +452,11 @@ def make( - a1_scale: Optional scale to be used for a1. - a2_scale: Optional scale to be used for a2. - g1_alphas: Optional global quantization scales for w1 (for nvfp4). - per-channel scales for w1 (for W4A8 FP8). 
+ Optional per-channel scales for w1 (for W4A8 FP8). + Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). - g2_alphas: Optional global quantization scales for w2 (for nvfp4). - per-channel scales for w2 (for W4A8 FP8). + Optional per-channel scales for w2 (for W4A8 FP8). + Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). - a1_gscale: Optional global quantization scales for a1 (for nvfp4). - a2_gscale: Optional global quantization scales for a2 (for nvfp4). - w1_bias: Optional biases for w1 (GPT OSS Triton). From 83a7d9b099cbc338512e0472ae0a8a0c243bbec2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:08:48 +0000 Subject: [PATCH 38/55] fix up Signed-off-by: Robert Shaw --- .../model_executor/layers/fused_moe/config.py | 2 +- .../model_executor/layers/quantization/fp8.py | 49 +++++++++++-------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index a29e7cf20db4..2a19b6b436d2 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -456,7 +456,7 @@ def make( Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). - g2_alphas: Optional global quantization scales for w2 (for nvfp4). Optional per-channel scales for w2 (for W4A8 FP8). - Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). + Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). - a1_gscale: Optional global quantization scales for a1 (for nvfp4). - a2_gscale: Optional global quantization scales for a2 (for nvfp4). - w1_bias: Optional biases for w1 (GPT OSS Triton). 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8bf238aac1e1..fa2870416a1c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -954,17 +954,16 @@ def _convert_weights_to_kernel_format( w2_scale=w2_weight_scale, w2_input_scale=w2_input_scale, ) - layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale + # NOTE(rob): the following are not set as nn.Parameters + # they are not needed for weight loading/re-loading. if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) layer.output1_scales_gate_scalar = g1_alphas - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv + layer.output1_scales_scalar = g1_alphas * ( + 1.0 / layer.w2_input_scale + ) layer.output2_scales_scalar = g2_alphas - else: - assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS - layer.g1_alphas = g1_alphas - layer.g2_alphas = g2_alphas elif self.fp8_backend == Fp8MoeBackend.AITER: w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( @@ -1237,27 +1236,37 @@ def get_fused_moe_quant_config( block_shape=self.weight_block_size, ) - # All other backends use W8A8 config. - g1_alphas = None - g2_alphas = None - a2_scale = layer.w13_input_scale + w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") + w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale + + # Flashinfer CUTLASS per-tensor uses single dq scale + # (alpha = w_scale * a_scale) and inverse a2 scale. 
if ( self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and not self.block_quant ): - # CUTLASS uses alpha = weight_s * input_s and a2_scale_inv - assert ( - hasattr(layer, "w2_input_scale_inv") - and hasattr(layer, "g1_alphas") - and hasattr(layer, "g2_alphas") + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + w1_scale, + a1_scale, + w2_scale, + a2_scale, + ) + return fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=(1.0 / a2_scale), + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, ) - g1_alphas = layer.g1_alphas - g2_alphas = layer.g2_alphas + # All other backends use normal config. return fp8_w8a8_moe_quant_config( - w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), - w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), - a1_scale=layer.w13_input_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, a2_scale=a2_scale, block_shape=self.weight_block_size, g1_alphas=g1_alphas, From 7300bc587dd330a3c58d6cca0a3280306345dcef Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:09:05 +0000 Subject: [PATCH 39/55] fix up Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fa2870416a1c..8d8a44201625 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1269,8 +1269,6 @@ def get_fused_moe_quant_config( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=self.weight_block_size, - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, ) @property From 344167d042e3a2b3ae90dc795f32762b11dd9d69 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:11:52 +0000 Subject: [PATCH 40/55] fix up configs Signed-off-by: Robert Shaw --- .../layers/quantization/modelopt.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) 
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index c74314fa0b4f..3e3b4f3f0791 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -953,9 +953,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w2_input_scale=layer.w2_input_scale, ) - # Flashinfer uses the inverse input scale for the activation. - layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale - # NOTE(rob): these scales do not need to be registered as parameters # since they are not used for weight (re)-loading. if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: @@ -963,9 +960,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - layer.g1_alphas = g1_alphas - layer.g2_alphas = g2_alphas else: raise ValueError( f"Unknown FlashInfer MoE backend: {self.flashinfer_moe_backend}" @@ -1021,15 +1015,23 @@ def get_fused_moe_quant_config( if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: # TRTLLM does not use modular kernels return None + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - # CUTLASS uses alpha = weight_s * input_s and a2_scale_inv + # Flashinfer CUTLASS per-tensor uses single dq scale + # (alpha = w_scale * a_scale) and inverse a2 scale. 
+ g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + layer.w13_weight_scale, + layer.w13_input_scale, + layer.w2_weight_scale, + layer.w2_input_scale, + ) return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale_inv, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, + a2_scale=(1.0 / layer.w2_input_scale), + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, ) else: assert self.flashinfer_moe_backend is None From b887c4f3edc32f1bd785a68f9d69f2dd3b95a20a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:15:22 +0000 Subject: [PATCH 41/55] cleanup Signed-off-by: Robert Shaw --- .../layers/quantization/modelopt.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 3e3b4f3f0791..b1267c33f124 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -944,19 +944,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.moe.is_act_and_mul: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - # g1_alphas = w13_weight_scale * w13_input_scale - # g2_alphas = w2_weight_scale * w2_input_scale - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w13_scale=layer.w13_weight_scale, - w13_input_scale=layer.w13_input_scale, - w2_scale=layer.w2_weight_scale, - w2_input_scale=layer.w2_input_scale, - ) - # NOTE(rob): these scales do not need to be registered as parameters # since they are not used for weight (re)-loading. 
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + w13_scale=layer.w13_weight_scale, + w13_input_scale=layer.w13_input_scale, + w2_scale=layer.w2_weight_scale, + w2_input_scale=layer.w2_input_scale, + ) layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas From d4d42310959fa29e53afeaee47374fd25b7c1545 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:16:15 +0000 Subject: [PATCH 42/55] remove unneeded assert Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/modelopt.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b1267c33f124..cb948513685c 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -957,10 +957,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas - else: - raise ValueError( - f"Unknown FlashInfer MoE backend: {self.flashinfer_moe_backend}" - ) def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. 
From 218e69766374d8976b0c52d169666acd50e9fb99 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:18:33 +0000 Subject: [PATCH 43/55] standardize how we add the params Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8d8a44201625..acd1f5554b07 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -947,18 +947,16 @@ def _convert_weights_to_kernel_format( if self.block_quant: w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: - assert self.quant_config.activation_scheme == "static" - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w13_scale=w13_weight_scale, - w13_input_scale=w13_input_scale, - w2_scale=w2_weight_scale, - w2_input_scale=w2_input_scale, - ) - # NOTE(rob): the following are not set as nn.Parameters # they are not needed for weight loading/re-loading. 
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + w13_scale=w13_weight_scale, + w13_input_scale=w13_input_scale, + w2_scale=w2_weight_scale, + w2_input_scale=w2_input_scale, + ) layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_scalar = g1_alphas * ( 1.0 / layer.w2_input_scale From b2e3a5002b10f1367af33d533a501887da71cf3d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:22:46 +0000 Subject: [PATCH 44/55] updated Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 15 --------------- vllm/model_executor/layers/quantization/fp8.py | 5 ++--- .../layers/quantization/modelopt.py | 1 + 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index e2518ed069f0..e5384f69f35f 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, - make_fp8_moe_alpha_scales_for_fi, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31, ) @@ -122,20 +121,6 @@ def make_moe_tensors_8bit( all2all_backend="naive", ) - # Convert scales to flashinfer format. 
- w13_weight_scale, w13_input_scale, w2_weight_scale, w2_input_scale = ( - make_fp8_moe_alpha_scales_for_fi( - w13_scale=layer.w13_weight_scale, - w13_input_scale=layer.w13_input_scale, - w2_scale=layer.w2_weight_scale, - w2_input_scale=layer.w2_input_scale, - ) - ) - layer.w13_weight_scale.data = w13_weight_scale - layer.w13_input_scale.data = w13_input_scale - layer.w2_weight_scale.data = w2_weight_scale - layer.w2_input_scale.data = w2_input_scale - # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if reorder: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index acd1f5554b07..dbf888397f3a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -957,10 +957,9 @@ def _convert_weights_to_kernel_format( w2_scale=w2_weight_scale, w2_input_scale=w2_input_scale, ) + layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale layer.output1_scales_gate_scalar = g1_alphas - layer.output1_scales_scalar = g1_alphas * ( - 1.0 / layer.w2_input_scale - ) + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas elif self.fp8_backend == Fp8MoeBackend.AITER: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index cb948513685c..694892aded19 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -954,6 +954,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w2_scale=layer.w2_weight_scale, w2_input_scale=layer.w2_input_scale, ) + layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas From dd3041620b73224be0ca467f56e6ec264e9b8499 Mon Sep 17 00:00:00 
2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:30:08 +0000 Subject: [PATCH 45/55] updated Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 15 ++++++++++---- .../layers/quantization/modelopt.py | 12 +++++------ .../quantization/utils/flashinfer_utils.py | 20 +++++++++++++++++++ 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index e5384f69f35f..f90cfa135a07 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, + register_scales_for_trtllm_fp8_per_tensor_moe, rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31, ) @@ -83,7 +84,7 @@ class TestData: @staticmethod def make_moe_tensors_8bit( - m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu" + m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu" ) -> "TestData": is_gated = activation != "relu2_no_mul" @@ -123,8 +124,14 @@ def make_moe_tensors_8bit( # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - if reorder: + if is_trtllm: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + register_scales_for_trtllm_fp8_per_tensor_moe( + layer.w13_weight, + layer.w13_input_scale, + layer.w2_weight_scale, + layer.w2_input_scale, + ) layer.custom_routing_function = Llama4MoE.custom_routing_function layer.intermediate_size_per_partition = n layer.ep_rank = 0 @@ -158,7 +165,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): - td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True) + td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True) score = torch.randn((m, e), 
device="cuda", dtype=torch.bfloat16) topk_weights, topk_ids = Llama4MoE.custom_routing_function( @@ -223,7 +230,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): td = TestData.make_moe_tensors_8bit( - m, k, n, e, reorder=False, activation=activation + m, k, n, e, is_trtllm=False, activation=activation ) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 694892aded19..4c24cf113de4 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -51,6 +51,7 @@ get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, make_fp8_moe_alpha_scales_for_fi, + register_scales_for_trtllm_fp8_per_tensor_moe, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -948,16 +949,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # since they are not used for weight (re)-loading. 
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w13_scale=layer.w13_weight_scale, + register_scales_for_trtllm_fp8_per_tensor_moe( + layer=layer, + w13_weight_scale=layer.w13_weight_scale, w13_input_scale=layer.w13_input_scale, - w2_scale=layer.w2_weight_scale, + w2_weight_scale=layer.w2_weight_scale, w2_input_scale=layer.w2_input_scale, ) - layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale - layer.output1_scales_gate_scalar = g1_alphas - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv - layer.output2_scales_scalar = g2_alphas def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index efdea432d620..409eabc3845a 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -103,6 +103,26 @@ def rotate_flashinfer_fp8_moe_weights( ) +def register_scales_for_trtllm_fp8_per_tensor_moe( + layer: torch.nn.Module, + w13_weight_scale: torch.Tensor, + w13_input_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w2_input_scale: torch.Tensor, +) -> None: + """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel""" + g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( + w13_scale=w13_weight_scale, + w13_input_scale=w13_input_scale, + w2_scale=w2_weight_scale, + w2_input_scale=w2_input_scale, + ) + layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale + layer.output1_scales_gate_scalar = g1_alphas + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv + layer.output2_scales_scalar = g2_alphas + + def 
apply_flashinfer_per_tensor_scale_fp8( layer: torch.nn.Module, hidden_states: torch.Tensor, From 3d22ba38fe4d9cc71c8ca34acb4d55e9dffe8c06 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 20:31:09 +0000 Subject: [PATCH 46/55] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index dbf888397f3a..dacc066b4ce1 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -51,6 +51,7 @@ build_flashinfer_fp8_cutlass_moe_prepare_finalize, get_flashinfer_moe_backend, make_fp8_moe_alpha_scales_for_fi, + register_scales_for_trtllm_fp8_per_tensor_moe, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, @@ -951,16 +952,13 @@ def _convert_weights_to_kernel_format( # they are not needed for weight loading/re-loading. 
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w13_scale=w13_weight_scale, + register_scales_for_trtllm_fp8_per_tensor_moe( + layer=layer, + w13_weight_scale=w13_weight_scale, w13_input_scale=w13_input_scale, - w2_scale=w2_weight_scale, + w2_weight_scale=w2_weight_scale, w2_input_scale=w2_input_scale, ) - layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale - layer.output1_scales_gate_scalar = g1_alphas - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv - layer.output2_scales_scalar = g2_alphas elif self.fp8_backend == Fp8MoeBackend.AITER: w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( From 56edecab991e94abb51f73f834473e593407c11d Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 15:35:28 -0500 Subject: [PATCH 47/55] a few small nits Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 2 -- .../layers/quantization/utils/flashinfer_utils.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index dacc066b4ce1..c4e99c8edb43 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -948,8 +948,6 @@ def _convert_weights_to_kernel_format( if self.block_quant: w13_weight_scale = swap_w13_to_w31(w13_weight_scale) else: - # NOTE(rob): the following are not set as nn.Parameters - # they are not needed for weight loading/re-loading. 
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) register_scales_for_trtllm_fp8_per_tensor_moe( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 409eabc3845a..b73c44b3130d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -117,7 +117,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe( w2_scale=w2_weight_scale, w2_input_scale=w2_input_scale, ) - layer.w2_input_scale_inv = 1.0 / layer.w2_input_scale + layer.w2_input_scale_inv = 1.0 / w2_input_scale layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas From e917f5db7517ee29f5dbf244f4d15d0f688ae6a0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 15:37:20 -0500 Subject: [PATCH 48/55] fix tests Signed-off-by: Robert Shaw --- tests/kernels/moe/test_flashinfer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index f90cfa135a07..17ec5f003b62 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -127,7 +127,8 @@ def make_moe_tensors_8bit( if is_trtllm: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) register_scales_for_trtllm_fp8_per_tensor_moe( - layer.w13_weight, + layer, + layer.w13_weight_scale, layer.w13_input_scale, layer.w2_weight_scale, layer.w2_input_scale, From 39987f68006cf74d15e9698ac3c46dcbda6a6f8f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 15:44:23 -0500 Subject: [PATCH 49/55] revert the llama weight loading hack Signed-off-by: Robert Shaw --- vllm/model_executor/models/llama4.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 6929234612c2..7b3da3e10ab8 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,7 +432,6 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. expert_param_loaded = False - loaded_weight = loaded_weight.to("cuda") # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last From 140f447c3a480663f34c00dfb0c87c5df88c166f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 16:06:46 -0500 Subject: [PATCH 50/55] stash Signed-off-by: Robert Shaw --- vllm/model_executor/models/llama4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 7b3da3e10ab8..6929234612c2 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,6 +432,7 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. expert_param_loaded = False + loaded_weight = loaded_weight.to("cuda") # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last From de6faa1e5a08f2227fb75464bd538be20c698a58 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 2 Jan 2026 16:47:08 -0500 Subject: [PATCH 51/55] unstash Signed-off-by: Robert Shaw --- vllm/model_executor/models/llama4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 6929234612c2..7b3da3e10ab8 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -432,7 +432,6 @@ def load_moe_expert_weights( # Whether the MoE expert weights are loaded successfully. 
expert_param_loaded = False - loaded_weight = loaded_weight.to("cuda") # If fused is True, the loaded weight is in the layout of: # [num_experts, hidden_in, hidden_out], so we must transpose the last From 12a0638b58d349f3e7fdb6221647c84c62f5bfad Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 5 Jan 2026 14:04:50 +0000 Subject: [PATCH 52/55] update cutlass per tensor Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/config.py | 5 +++-- .../layers/fused_moe/flashinfer_cutlass_moe.py | 4 ++-- vllm/model_executor/layers/quantization/modelopt.py | 4 +++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 2a19b6b436d2..3f298f7a5ca2 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -457,8 +457,9 @@ def make( - g2_alphas: Optional global quantization scales for w2 (for nvfp4). Optional per-channel scales for w2 (for W4A8 FP8). Optional dq scale i.e. w_scale * a_scale (for W8A8 fp8). - - a1_gscale: Optional global quantization scales for a1 (for nvfp4). - - a2_gscale: Optional global quantization scales for a2 (for nvfp4). + - a1_gscale: Optional global quantization scales for a1 (1.0 / a1_scale). + - a2_gscale: Optional global quantization scales for a2 (1.0 / a2_scale). + - w1_bias: Optional biases for w1 (GPT OSS Triton). - w2_bias: Optional biases for w1 (GPT OSS Triton). - w1_zp: Optional w1 zero points for int4/int8 quantization. 
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 7cfae4ff66c3..5a2ba1601b6f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -166,9 +166,9 @@ def apply( # FP8 per-tensor path: use global alphas/scales; do not pass input_sf quant_scales = [ self.g1_alphas, # w13_weight_scale * w13_input_scale - self.a2_scale, + self.a2_gscale, # 1.0 / w2_input_scale self.g2_alphas, # w2_weight_scale * w2_input_scale - self.a1_scale, + self.a1_scale, # w13_input_scale ] a1q_scale = None # not passing input_sf in fp8 diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 5f12c058a005..0657d65741a3 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1013,7 +1013,9 @@ def get_fused_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, - a2_scale=(1.0 / layer.w2_input_scale), + a2_scale=layer.w2_input_scale, + a1_gscale=(1.0 / layer.w13_input_scale), + a2_gscale=(1.0 / layer.w2_input_scale), g1_alphas=g1_alphas, g2_alphas=g2_alphas, ) From 59db9a987f25e404fc6cc13e93a5388ac2bab0d9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 5 Jan 2026 14:05:36 +0000 Subject: [PATCH 53/55] update cutlass per tensor Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 5a2ba1601b6f..a459b12d2e4c 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -168,7 +168,7 @@ def 
apply( self.g1_alphas, # w13_weight_scale * w13_input_scale self.a2_gscale, # 1.0 / w2_input_scale self.g2_alphas, # w2_weight_scale * w2_input_scale - self.a1_scale, # w13_input_scale + self.a1_scale, ] a1q_scale = None # not passing input_sf in fp8 From d6c4a87b8816a58df590ae58c2cf14d2ea3e7ea3 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 5 Jan 2026 14:08:23 +0000 Subject: [PATCH 54/55] update Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/modelopt.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0657d65741a3..f29b77b3aa3d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -937,8 +937,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.moe.is_act_and_mul: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - # NOTE(rob): these scales do not need to be registered as parameters - # since they are not used for weight (re)-loading. + # NOTE: this adds some attributes used by the trtllm kernel, + # which does not conform to the modular kernels abstraction (yet). if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) register_scales_for_trtllm_fp8_per_tensor_moe( @@ -1001,8 +1001,6 @@ def get_fused_moe_quant_config( return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - # Flashinfer CUTLASS per-tensor uses single dq scale - # (alpha = w_scale * a_scale) and inverse a2 scale. 
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( layer.w13_weight_scale, layer.w13_input_scale, From ce913de0a53b342d1475ea1091ee00a31184a15a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 5 Jan 2026 14:28:53 +0000 Subject: [PATCH 55/55] fix trtllm issue Signed-off-by: Robert Shaw --- vllm/model_executor/layers/quantization/fp8.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f47de57ead8a..512a85180b7e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1005,6 +1005,10 @@ def _setup_kernel(self, layer: Module) -> None: AiterExperts, ) + # Flashinfer TRTLLM does not use the modular kernel abstraction. + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + return + self.moe_quant_config = self.get_fused_moe_quant_config(layer) assert self.moe_quant_config is not None self.use_inplace = True