diff --git a/tpu_inference/layers/vllm/quantization/fp8.py b/tpu_inference/layers/vllm/quantization/fp8.py
index ae22bf4044..86ecc60d47 100644
--- a/tpu_inference/layers/vllm/quantization/fp8.py
+++ b/tpu_inference/layers/vllm/quantization/fp8.py
@@ -22,7 +22,8 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoERouter
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                  FusedMoERouter)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
@@ -242,7 +243,12 @@ def __init__(self,
                  layer: torch.nn.Module,
                  mesh: Mesh,
                  ep_axis_name: str = "model"):
-        super().__init__(quant_config, layer)
+        FusedMoEMethodBase.__init__(self, layer.moe_config)
+        self.quant_config = quant_config
+        self.weight_block_size = self.quant_config.weight_block_size
+        self.block_quant: bool = self.weight_block_size is not None
+        self.weight_scale_name = ("weight_scale_inv"
+                                  if self.block_quant else "weight_scale")
         self.mesh = mesh
         self.moe_backend = select_moe_backend(self.moe)
 
@@ -251,6 +257,10 @@ def __init__(self,
         if self.moe_backend == FusedMoEBackend.FUSED_MOE:
             self.extra_backend_kwargs = dict(ep_axis_name=ep_axis_name, )
 
+    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
+        # Override quant config fused moe
+        pass
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
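
Reviewer note: the `__init__` hunk replaces `super().__init__(quant_config, layer)` with a direct call to `FusedMoEMethodBase.__init__`, skipping the immediate parent's initializer and re-creating only the quant-config bookkeeping the TPU path needs. Below is a minimal, self-contained sketch of that pattern; `UpstreamFp8MoEMethod`, `TpuFp8MoEMethod`, `gpu_only_state`, and the `SimpleNamespace` stand-ins are illustrative assumptions, not the actual vLLM or tpu_inference definitions. Only the `FusedMoEMethodBase.__init__(self, layer.moe_config)` call and the attribute assignments mirror the diff itself.

```python
# Sketch of the "skip the parent, init the shared base directly" pattern.
# Everything except FusedMoEMethodBase's name is hypothetical.
from types import SimpleNamespace


class FusedMoEMethodBase:
    def __init__(self, moe_config):
        # The diff later reads self.moe via select_moe_backend(self.moe),
        # so this base init is assumed to populate it.
        self.moe = moe_config


class UpstreamFp8MoEMethod(FusedMoEMethodBase):
    def __init__(self, quant_config, layer):
        super().__init__(layer.moe_config)
        # Stand-in for GPU-oriented setup the TPU subclass wants to avoid.
        self.gpu_only_state = object()


class TpuFp8MoEMethod(UpstreamFp8MoEMethod):
    def __init__(self, quant_config, layer):
        # Bypass UpstreamFp8MoEMethod.__init__ and initialize the shared
        # base directly, as the hunk above does.
        FusedMoEMethodBase.__init__(self, layer.moe_config)
        self.quant_config = quant_config
        self.weight_block_size = quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        # Block-quantized FP8 checkpoints store inverse scales under
        # "weight_scale_inv"; per-tensor checkpoints use "weight_scale".
        self.weight_scale_name = ("weight_scale_inv"
                                  if self.block_quant else "weight_scale")


layer = SimpleNamespace(moe_config="moe-cfg")
method = TpuFp8MoEMethod(SimpleNamespace(weight_block_size=[128, 128]), layer)
assert method.block_quant and method.weight_scale_name == "weight_scale_inv"
assert not hasattr(method, "gpu_only_state")  # parent __init__ was skipped
```

The design point: calling the shared base directly avoids the parent's setup while still populating `self.moe`, which `select_moe_backend(self.moe)` consumes later in the same `__init__`.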