Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions tpu_inference/layers/vllm/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoERouter
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoERouter)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import \
register_quantization_config
Expand Down Expand Up @@ -242,7 +243,12 @@ def __init__(self,
layer: torch.nn.Module,
mesh: Mesh,
ep_axis_name: str = "model"):
super().__init__(quant_config, layer)
FusedMoEMethodBase.__init__(self, layer.moe_config)
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.weight_scale_name = ("weight_scale_inv"
if self.block_quant else "weight_scale")

self.mesh = mesh
self.moe_backend = select_moe_backend(self.moe)
Expand All @@ -251,6 +257,10 @@ def __init__(self,
if self.moe_backend == FusedMoEBackend.FUSED_MOE:
self.extra_backend_kwargs = dict(ep_axis_name=ep_axis_name, )

def get_fused_moe_quant_config(self, layer: torch.nn.Module):
    """Return ``None`` without building a fused-MoE quant config.

    No-op override: ``layer`` is accepted but never read, and nothing on
    ``self`` is touched.

    NOTE(review): presumably this replaces the hook inherited from
    ``FusedMoEMethodBase`` -- confirm against the base class that a
    ``None`` result is an accepted return value.
    """
    return None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
assert isinstance(layer, FusedMoE)

Expand Down
Loading