27 changes: 19 additions & 8 deletions vllm/model_executor/layers/fused_moe/utils.py
@@ -5,15 +5,15 @@

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
@@ -122,15 +122,26 @@ def _fp8_quantize(
is provided, the output will be blocked.
"""
if block_shape is None:
# TODO(luka): use QuantFP8 custom op
# https://github.com/vllm-project/vllm/issues/20711
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_act_token)
if per_act_token:
group_shape = GroupShape.PER_TOKEN
else:
group_shape = GroupShape.PER_TENSOR

quant_op = QuantFP8(static=(A_scale is not None),
group_shape=group_shape)
A, A_scale = quant_op(A, A_scale)
else:
assert not per_act_token
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)

group_shape = GroupShape(1, block_k)
quant_op = QuantFP8(
static=False, # Group quantization is always dynamic
group_shape=group_shape,
column_major_scales=False # Use row-major for MoE
)
A, A_scale = quant_op(A)
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

return A, A_scale
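
For reference, a minimal sketch of how the two branches above are exercised; the tensor sizes and device are illustrative assumptions, not part of this change:

import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

A = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda")

# block_shape is None, per_act_token=True -> dynamic per-token quantization
per_token_op = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)
A_q, A_scale = per_token_op(A)            # one scale per token row

# block_shape == [128, 128] -> block_k = 128, dynamic group quantization
group_op = QuantFP8(static=False, group_shape=GroupShape(1, 128),
                    column_major_scales=False)
A_q, A_scale = group_op(A)                # A_scale.size(-1) == cdiv(256, 128) == 2
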
75 changes: 54 additions & 21 deletions vllm/model_executor/layers/quantization/input_quant_fp8.py
@@ -23,40 +23,78 @@
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
"""
Quantize input tensor to per-tensor or per-token FP8.
Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
This CustomOp supports both static and dynamic quantization.
"""

def __init__(self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None):
num_token_padding: Optional[int] = None,
column_major_scales: bool = False):
"""

:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR)
:param num_token_padding: Pad the token dimension of output to this size
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param column_major_scales: For group quantization, output scales in
column-major format
"""
super().__init__()
self.num_token_padding = num_token_padding
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
assert not static or group_shape == GroupShape.PER_TENSOR, \
"Only per-tensor scales supported for static quantization."
self.static = static
self.group_shape = group_shape
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales

self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant:
assert not static, "Group quantization only supports dynamic mode"
self.group_size = group_shape.col
else:
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}

Review comment (Member): Can we add an assert that column_major_scales is False if non group?

assert not static or group_shape == GroupShape.PER_TENSOR, \
"Only per-tensor scales supported for static quantization."
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
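
As a quick illustration of the constructor contract documented above (the argument values here are assumptions for illustration, not taken from vLLM call sites):

# Static per-tensor quantization: a scale tensor must be passed to forward().
static_op = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

# Dynamic per-token quantization: the scale is computed from the input.
per_token_op = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

# Dynamic group quantization with groups of 128 along the last dimension.
group_op = QuantFP8(static=False, group_shape=GroupShape(1, 128),
                    column_major_scales=True)

# Rejected by the asserts above: static group quantization.
# QuantFP8(static=True, group_shape=GroupShape(1, 128))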

def _quantize_group(self,
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
return per_token_group_quant_fp8(
x,
group_size=self.group_size,
column_major_scales=self.column_major_scales,
dtype=_FP8_DTYPE)

def _compute_dynamic_scale(
self, x: torch.Tensor,
scale_ub: Optional[torch.Tensor]) -> torch.Tensor:
if self.group_shape == GroupShape.PER_TOKEN:
x_max, _ = x.abs().max(dim=-1)
x_max = x_max.unsqueeze(-1).to(torch.float32)
if scale_ub is not None:
x_max = x_max.clamp(max=scale_ub)
else:
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)

scale = x_max / _FP8_MAX
return scale.clamp(min=_FP8_MIN_SCALING_FACTOR)
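
To make the arithmetic above concrete, a tiny worked example; it assumes _FP8_MAX is 448.0 (the torch.float8_e4m3fn maximum), which this hunk does not show:

import torch

x = torch.tensor([[0.5, -8.96, 3.2]])
x_max, _ = x.abs().max(dim=-1)       # tensor([8.9600])
x_max = x_max.unsqueeze(-1)          # shape (1, 1): one scale per token
scale = x_max / 448.0                # tensor([[0.0200]])
# Quantization stores x / scale in FP8; dequantization multiplies back by scale,
# and the clamp to _FP8_MIN_SCALING_FACTOR guards against zero scales for all-zero rows.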

def forward_cuda(
self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
return self._quantize_group(x)

assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape
== GroupShape.PER_TOKEN
and scale_ub.numel() == 1)

return ops.scaled_fp8_quant(
x,
scale,
@@ -70,22 +108,17 @@ def forward_native(
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
):
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
return self._quantize_group(x)

assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape
== GroupShape.PER_TOKEN
and scale_ub.numel() == 1)

if scale is None:
if self.group_shape == GroupShape.PER_TOKEN:
x_max, _ = x.abs().max(dim=-1)
x_max = x_max.unsqueeze(-1).to(torch.float32)
if scale_ub is not None:
x_max = x_max.clamp(max=scale_ub)
else:
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)

scale = x_max / _FP8_MAX
scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR)
scale = self._compute_dynamic_scale(x, scale_ub)

# Even for dynamic per-token scales,
# reciprocal performs slightly better than division
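
The quantization step that follows this comment is elided here; as an assumed sketch (not the actual elided code) of the reciprocal-multiply idea, using the _FP8_MAX and _FP8_DTYPE constants referenced above:

# Multiply by a precomputed reciprocal instead of dividing element-wise.
inv_scale = scale.reciprocal()
x_q = (x.to(torch.float32) * inv_scale).clamp(-_FP8_MAX, _FP8_MAX).to(_FP8_DTYPE)
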
vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -34,6 +34,15 @@ class GroupShape(_GroupShape):
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']

def is_per_tensor(self) -> bool:
return self.row == -1 and self.col == -1

def is_per_token(self) -> bool:
return self.row == 1 and self.col == -1

def is_per_group(self) -> bool:
return self.row == 1 and self.col >= 1


GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
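
For illustration, the new predicates behave as follows on these sentinel shapes:

GroupShape.PER_TENSOR.is_per_tensor()   # True
GroupShape.PER_TOKEN.is_per_token()     # True
GroupShape(1, 128).is_per_group()       # True: a positive group size along the columns
GroupShape.PER_TOKEN.is_per_group()     # False: col is -1, not a group size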