diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 22edb9c58daf..7cdb135fa077 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -38,6 +38,5 @@ def test_model_experts_int8_startup( dtype=dtype, enforce_eager=True, quantization="experts_int8", - allow_deprecated_quantization=True, ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 9aceb3be054d..c0ba8aa600a6 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -40,7 +40,6 @@ "tpu_int8", "fbgemm_fp8", "fp_quant", - "experts_int8", "petit_nvfp4", ] diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index d971f3b5b0d2..a1a45a3e1a20 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -4,12 +4,11 @@ from typing import Any import torch +from torch.nn import Module -from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEConfig, - FusedMoEMethodBase, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -21,11 +20,14 @@ QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.quantization.online_moe import ( + OnlineMoEMethodBase, +) class ExpertsInt8Config(QuantizationConfig): - """Config class for Int8 experts quantization.""" + """Online int8 quantization for MoE expert weights. + Linear layers are left unquantized.""" def __init__(self) -> None: super().__init__() @@ -60,78 +62,65 @@ def get_quant_method( return None -class ExpertsInt8MoEMethod(FusedMoEMethodBase): +class ExpertsInt8MoEMethod(OnlineMoEMethodBase): + """Online int8 MoE quantization. Loads full-precision weights and + quantizes to int8 with per-row scales during model loading.""" + def __init__( self, - quant_config: ExpertsInt8Config, + quant_config: QuantizationConfig, moe: FusedMoEConfig, ): super().__init__(moe) self.quant_config = quant_config - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - int8_dtype = torch.int8 + def _quantize_weights(self, layer: Module) -> None: + vmax = torch.iinfo(torch.int8).max - assert "weight_loader" in extra_weight_attrs - weight_loader = extra_weight_attrs["weight_loader"] - wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader( - layer, weight_loader + w13 = torch.empty_like(layer.w13_weight, dtype=torch.int8) + w2 = torch.empty_like(layer.w2_weight, dtype=torch.int8) + w13_scale = torch.zeros( + layer.num_experts, + layer.w13_weight.shape[1], + device=w13.device, + dtype=torch.float32, ) - extra_weight_attrs["weight_loader"] = wrapped_weight_loader - - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=int8_dtype, - ), - requires_grad=False, + w2_scale = torch.zeros( + layer.num_experts, + layer.w2_weight.shape[1], + device=w2.device, + dtype=torch.float32, ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=int8_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - w13_scale = torch.nn.Parameter( - torch.zeros( - num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_scale", w13_scale) - w2_scale = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_scale", w2_scale) + for expert in range(layer.local_num_experts): + # w13: per-row quantization over hidden_size dim + w = layer.w13_weight[expert, :, :] + scales = w.abs().amax(dim=1) / vmax + q = w.div(scales.unsqueeze(1)).round().clamp(-vmax, vmax) + w13[expert, :, :] = q.to(torch.int8) + w13_scale[expert, :] = scales + + # w2: per-row quantization over intermediate_size dim + w = layer.w2_weight[expert, :, :] + scales = w.abs().amax(dim=1) / vmax + q = w.div(scales.unsqueeze(1)).round().clamp(-vmax, vmax) + w2[expert, :, :] = q.to(torch.int8) + w2_scale[expert, :] = scales + + # Replace full-precision weights with quantized versions + layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False) + layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False) + layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: return int8_w8a16_moe_quant_config( - w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + w1_zp=None, + w2_zp=None, ) def apply( @@ -157,48 +146,3 @@ def apply( expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) - - @staticmethod - def quantizing_weight_loader(layer, weight_loader): - def quantize_and_call_weight_loader( - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: int, - expert_id: int, - ): - tp_rank = get_tensor_model_parallel_rank() - shard_size = layer.intermediate_size_per_partition - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - device = get_tp_group().device - loaded_weight = loaded_weight.to(device) - # w1, gate_proj case: Load into first shard of w13. - if shard_id == "w1": - scales = quantize_in_place_and_get_scales(loaded_weight[shard, :]) - layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, 0]) - # w3, up_proj case: Load into second shard of w13. - elif shard_id == "w3": - scales = quantize_in_place_and_get_scales(loaded_weight[shard, :]) - layer.w13_scale.data[expert_id, shard_size : 2 * shard_size].copy_( - scales[:, 0] - ) - # w2, down_proj case: Load into only shard of w2. - elif shard_id == "w2": - scales = quantize_in_place_and_get_scales(loaded_weight[:, shard]) - layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) - else: - raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") - weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) - - return quantize_and_call_weight_loader - - -def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor: - vmax = torch.iinfo(torch.int8).max - scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax - - weight.div_(scales) - weight.round_() - weight.clamp_(-vmax, vmax) - - return scales diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fffcfa5e6329..36e61a338929 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -24,6 +24,7 @@ FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod @@ -44,6 +45,9 @@ QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.online_moe import ( + OnlineMoEMethodBase, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, create_fp8_input_scale, @@ -197,10 +201,9 @@ def get_quant_method( ): return UnquantizedFusedMoEMethod(layer.moe_config) if self.is_checkpoint_fp8_serialized: - moe_quant_method = Fp8MoEMethod(self, layer) + return Fp8MoEMethod(self, layer) else: - moe_quant_method = Fp8OnlineMoEMethod(self, layer) - return moe_quant_method + return Fp8OnlineMoEMethod(self, layer) elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -564,23 +567,17 @@ def process_weights_after_loading(self, layer: Module) -> None: layer._already_called_process_weights_after_loading = True -class Fp8MoEMethod(FusedMoEMethodBase): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. +class Fp8MoEKernelMixin: + """FP8 backend selection, kernel format conversion, and dispatch for MoE.""" - Also supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. - - Args: - quant_config: The quantization config. - """ + moe: FusedMoEConfig + moe_quant_config: FusedMoEQuantConfig | None + moe_kernel: mk.FusedMoEKernel | None + is_monolithic: bool - def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): - super().__init__(layer.moe_config) + def _init_fp8_backend(self, quant_config: "Fp8Config") -> None: self.quant_config = quant_config - self.weight_block_size = self.quant_config.weight_block_size + self.weight_block_size = quant_config.weight_block_size self.block_quant: bool = self.weight_block_size is not None self.weight_scale_name = ( "weight_scale_inv" if self.block_quant else "weight_scale" @@ -606,6 +603,142 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): allow_vllm_cutlass=False, ) + def _setup_kernel( + self, + layer: FusedMoE, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_input_scale: torch.Tensor | None, + w2_input_scale: torch.Tensor | None, + ) -> None: + # Shuffle weights to runtime format. + w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( + fp8_backend=self.fp8_backend, + layer=layer, + w13=w13, + w2=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, + ) + + # Replace parameters with updated versions. Note that this helper + # function ensures the replacement is compatible with RL weight reloads. + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) + replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) + + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config: + assert self.experts_cls is not None + self.moe_kernel = make_fp8_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, + ) + + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: + w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") + w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale + + quant_config = make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=self.weight_block_size, + ) + + # Inject biases into the quant config if the model has them + # (e.g. GPT-OSS biased MoE) + if quant_config is not None and self.moe.has_bias: + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + if w13_bias is not None: + quant_config._w1.bias = w13_bias + if w2_bias is not None: + quant_config._w2.bias = w2_bias + + return quant_config + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel " + "initialization logic. This function should not be called." + ) + + @property + def supports_eplb(self) -> bool: + return True + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) + + +class Fp8MoEMethod(Fp8MoEKernelMixin, FusedMoEMethodBase): + """MoE method for loading pre-quantized FP8 checkpoints with static + weight scales and dynamic/static activation scales.""" + + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + FusedMoEMethodBase.__init__(self, layer.moe_config) + self._init_fp8_backend(quant_config) + def create_weights( self, layer: Module, @@ -746,47 +879,6 @@ def create_weights( layer.w13_input_scale = None layer.w2_input_scale = None - def _setup_kernel( - self, - layer: FusedMoE, - w13: torch.Tensor, - w2: torch.Tensor, - w13_scale: torch.Tensor, - w2_scale: torch.Tensor, - w13_input_scale: torch.Tensor | None, - w2_input_scale: torch.Tensor | None, - ) -> None: - # Shuffle weights to runtime format. - w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( - fp8_backend=self.fp8_backend, - layer=layer, - w13=w13, - w2=w2, - w13_scale=w13_scale, - w2_scale=w2_scale, - w13_input_scale=w13_input_scale, - w2_input_scale=w2_input_scale, - ) - - # Replace parameters with updated versions. Note that this helper - # function ensures the replacement is compatible with RL weight reloads. - replace_parameter(layer, "w13_weight", w13) - replace_parameter(layer, "w2_weight", w2) - replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) - replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) - - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_kernel = make_fp8_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - fp8_backend=self.fp8_backend, - experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), - shared_experts=layer.shared_experts, - ) - def process_weights_after_loading(self, layer: Module) -> None: # Allow for accessing weights and scales in standard way. w13 = layer.w13_weight @@ -832,107 +924,14 @@ def process_weights_after_loading(self, layer: Module) -> None: layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) - def maybe_make_prepare_finalize( - self, - routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalizeModular | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: - w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") - w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") - a1_scale = layer.w13_input_scale - a2_scale = layer.w2_input_scale - - quant_config = make_fp8_moe_quant_config( - fp8_backend=self.fp8_backend, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=self.weight_block_size, - ) - - # Inject biases into the quant config if the model has them - # (e.g. GPT-OSS biased MoE) - if quant_config is not None and self.moe.has_bias: - w13_bias = getattr(layer, "w13_bias", None) - w2_bias = getattr(layer, "w2_bias", None) - if w13_bias is not None: - quant_config._w1.bias = w13_bias - if w2_bias is not None: - quant_config._w2.bias = w2_bias - - return quant_config - - @property - def supports_eplb(self) -> bool: - return True - - def apply_monolithic( - self, - layer: FusedMoE, - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.is_monolithic - assert self.moe_kernel is not None - return self.moe_kernel.apply_monolithic( - x, - layer.w13_weight, - layer.w2_weight, - router_logits, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - e_score_correction_bias=layer.e_score_correction_bias, - routed_scaling_factor=layer.routed_scaling_factor, - ) - - def apply( - self, - layer: FusedMoE, - x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not self.is_monolithic - assert self.moe_kernel is not None - return self.moe_kernel.apply( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) - -class Fp8OnlineMoEMethod(Fp8MoEMethod): - """MoE method for online FP8 quantization. - Supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. - - Args: - quant_config: The quantization config. - """ - - uses_meta_device: bool = True +class Fp8OnlineMoEMethod(Fp8MoEKernelMixin, OnlineMoEMethodBase): + """Online FP8 MoE quantization. Loads full-precision weights and + quantizes to FP8 with per-tensor scales during model loading.""" def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): - super().__init__(quant_config, layer) + OnlineMoEMethodBase.__init__(self, layer.moe_config) + self._init_fp8_backend(quant_config) assert not quant_config.is_checkpoint_fp8_serialized assert quant_config.activation_scheme == "dynamic" assert quant_config.weight_block_size is None @@ -946,44 +945,33 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - device="meta", - dtype=params_dtype, - ), - requires_grad=False, + super().create_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - device="meta", # materialized and processed during loading - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) + def _create_extra_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # BIASES (for models like GPT-OSS that have biased MoE) if self.moe.has_bias: w13_bias = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition, - device="meta", # materialized and processed during loading + device="meta", dtype=layer.orig_dtype, ), requires_grad=False, @@ -995,7 +983,7 @@ def create_weights( torch.zeros( num_experts, hidden_size, - device="meta", # materialized and processed during loading + device="meta", dtype=layer.orig_dtype, ), requires_grad=False, @@ -1003,13 +991,7 @@ def create_weights( layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs) - initialize_online_processing(layer) - - def process_weights_after_loading(self, layer: Module) -> None: - # TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - + def _quantize_weights(self, layer: Module) -> None: fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) @@ -1039,9 +1021,6 @@ def process_weights_after_loading(self, layer: Module) -> None: w2_input_scale=layer.w2_input_scale, ) - # Prevent duplicate processing (e.g., during weight reload) - layer._already_called_process_weights_after_loading = True - class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/layers/quantization/online_moe.py b/vllm/model_executor/layers/quantization/online_moe.py new file mode 100644 index 000000000000..a0f8ed8a9fcb --- /dev/null +++ b/vllm/model_executor/layers/quantization/online_moe.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod + +import torch + +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.model_loader.reload.layerwise import ( + initialize_online_processing, +) +from vllm.model_executor.utils import set_weight_attrs + + +class OnlineMoEMethodBase(FusedMoEMethodBase): + """Base for MoE methods that load full-precision weights and quantize + them during model loading via the QeRL layerwise processing system.""" + + uses_meta_device: bool = True + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + + # Fused gate_up_proj (column parallel) — full precision on meta device + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + device="meta", + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) — full precision on meta device + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + device="meta", + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Hook for subclasses to add extra params (biases, etc.) + # before initialize_online_processing counts total elements. + self._create_extra_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, + ) + + initialize_online_processing(layer) + + def _create_extra_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Override to create additional parameters before online processing + initialization. Called after w13/w2 weights are registered but before + ``initialize_online_processing``.""" + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + self._quantize_weights(layer) + layer._already_called_process_weights_after_loading = True + + @abstractmethod + def _quantize_weights(self, layer: torch.nn.Module) -> None: + """Quantize full-precision weights after all experts are loaded.""" + ...