Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/quantization/test_experts_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,5 @@ def test_model_experts_int8_startup(
dtype=dtype,
enforce_eager=True,
quantization="experts_int8",
allow_deprecated_quantization=True,
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
1 change: 0 additions & 1 deletion vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
"tpu_int8",
"fbgemm_fp8",
"fp_quant",
"experts_int8",
"petit_nvfp4",
]

Expand Down
154 changes: 49 additions & 105 deletions vllm/model_executor/layers/quantization/experts_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from typing import Any

import torch
from torch.nn import Module

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
Expand All @@ -21,11 +20,14 @@
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.online_moe import (
OnlineMoEMethodBase,
)


class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization."""
"""Online int8 quantization for MoE expert weights.
Linear layers are left unquantized."""

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -60,78 +62,65 @@ def get_quant_method(
return None


class ExpertsInt8MoEMethod(FusedMoEMethodBase):
class ExpertsInt8MoEMethod(OnlineMoEMethodBase):
"""Online int8 MoE quantization. Loads full-precision weights and
quantizes to int8 with per-row scales during model loading."""

def __init__(
self,
quant_config: ExpertsInt8Config,
quant_config: QuantizationConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
int8_dtype = torch.int8
def _quantize_weights(self, layer: Module) -> None:
vmax = torch.iinfo(torch.int8).max

assert "weight_loader" in extra_weight_attrs
weight_loader = extra_weight_attrs["weight_loader"]
wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
layer, weight_loader
w13 = torch.empty_like(layer.w13_weight, dtype=torch.int8)
w2 = torch.empty_like(layer.w2_weight, dtype=torch.int8)
w13_scale = torch.zeros(
layer.num_experts,
layer.w13_weight.shape[1],
device=w13.device,
dtype=torch.float32,
)
extra_weight_attrs["weight_loader"] = wrapped_weight_loader

# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=int8_dtype,
),
requires_grad=False,
w2_scale = torch.zeros(
layer.num_experts,
layer.w2_weight.shape[1],
device=w2.device,
dtype=torch.float32,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=int8_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

w13_scale = torch.nn.Parameter(
torch.zeros(
num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
),
requires_grad=False,
)
layer.register_parameter("w13_scale", w13_scale)

w2_scale = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_scale", w2_scale)
for expert in range(layer.local_num_experts):
# w13: per-row quantization over hidden_size dim
w = layer.w13_weight[expert, :, :]
scales = w.abs().amax(dim=1) / vmax
q = w.div(scales.unsqueeze(1)).round().clamp(-vmax, vmax)
w13[expert, :, :] = q.to(torch.int8)
w13_scale[expert, :] = scales

# w2: per-row quantization over intermediate_size dim
w = layer.w2_weight[expert, :, :]
scales = w.abs().amax(dim=1) / vmax
q = w.div(scales.unsqueeze(1)).round().clamp(-vmax, vmax)
w2[expert, :, :] = q.to(torch.int8)
w2_scale[expert, :] = scales

# Replace full-precision weights with quantized versions
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return int8_w8a16_moe_quant_config(
w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
w1_zp=None,
w2_zp=None,
)

def apply(
Expand All @@ -157,48 +146,3 @@ def apply(
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)

@staticmethod
def quantizing_weight_loader(layer, weight_loader):
def quantize_and_call_weight_loader(
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: int,
expert_id: int,
):
tp_rank = get_tensor_model_parallel_rank()
shard_size = layer.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
device = get_tp_group().device
loaded_weight = loaded_weight.to(device)
# w1, gate_proj case: Load into first shard of w13.
if shard_id == "w1":
scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, 0])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == "w3":
scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
layer.w13_scale.data[expert_id, shard_size : 2 * shard_size].copy_(
scales[:, 0]
)
# w2, down_proj case: Load into only shard of w2.
elif shard_id == "w2":
scales = quantize_in_place_and_get_scales(loaded_weight[:, shard])
layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
else:
raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}")
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)

return quantize_and_call_weight_loader


def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
vmax = torch.iinfo(torch.int8).max
scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax

weight.div_(scales)
weight.round_()
weight.clamp_(-vmax, vmax)

return scales
Loading
Loading