From 34bbc4af880cffe53694c0ae712e3ee694e9b7c8 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 20 Jan 2026 16:49:05 +0000 Subject: [PATCH 1/4] fix the semantics, weight descriptor should use the term channel rather than token Signed-off-by: tjtanaa --- .../schemes/compressed_tensors_w8a8_fp8.py | 4 ++-- .../layers/quantization/fbgemm_fp8.py | 4 ++-- .../layers/quantization/modelopt.py | 4 ++-- .../quark/schemes/quark_w8a8_fp8.py | 4 ++-- .../layers/quantization/utils/quant_utils.py | 21 +++++++++---------- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 1120202f29fd..34d403a21783 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -29,8 +29,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8DynamicTokenSym, + kFp8StaticChannelSym, kFp8StaticTensorSym, - kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, @@ -56,7 +56,7 @@ DYNAMIC_QUANT: kFp8DynamicTokenSym, } weight_quant_key_mapping = { - QuantizationStrategy.CHANNEL: kFp8StaticTokenSym, + QuantizationStrategy.CHANNEL: kFp8StaticChannelSym, QuantizationStrategy.TENSOR: kFp8StaticTensorSym, } logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 45d2e4e33819..ded1780ed5f6 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, 
kFp8DynamicTokenSym, - kFp8StaticTokenSym, + kFp8StaticChannelSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, @@ -96,7 +96,7 @@ def __init__(self, quant_config: FBGEMMFp8Config): self.out_dtype = torch.get_default_dtype() self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=kFp8DynamicTokenSym, - weight_quant_key=kFp8StaticTokenSym, + weight_quant_key=kFp8StaticChannelSym, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 91dfa03b8b86..15dd3d70bd54 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -82,8 +82,8 @@ cutlass_fp4_supported, is_layer_skipped, kFp8DynamicTokenSym, + kFp8StaticChannelSym, kFp8StaticTensorSym, - kFp8StaticTokenSym, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -531,7 +531,7 @@ def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=kFp8DynamicTokenSym, - weight_quant_key=kFp8StaticTokenSym, + weight_quant_key=kFp8StaticChannelSym, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 635b5cf894ef..36255b98591e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -15,8 +15,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8DynamicTokenSym, + kFp8StaticChannelSym, kFp8StaticTensorSym, - kFp8StaticTokenSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils 
import ( normalize_e4m3fn_to_e4m3fnuz, @@ -54,7 +54,7 @@ def __init__( kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym ) self.weight_quant_key = ( - kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym + kFp8StaticChannelSym if per_token_weight else kFp8StaticTensorSym ) self.out_dtype = torch.get_default_dtype() diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 91fc8760b5ef..f3879e40f49b 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -48,6 +48,7 @@ class GroupShape(_GroupShape): # Aliases for common quantization group shapes PER_TENSOR: ClassVar["GroupShape"] PER_TOKEN: ClassVar["GroupShape"] + PER_CHANNEL: ClassVar["GroupShape"] def is_per_tensor(self) -> bool: return self.row == -1 and self.col == -1 @@ -61,6 +62,7 @@ def is_per_group(self) -> bool: GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +GroupShape.PER_CHANNEL = GroupShape(-1, 1) @dataclass(frozen=True) @@ -77,15 +79,12 @@ class ScaleDesc: group_shape: GroupShape def __str__(self): - group_shape = ( - "per_tensor" - if self.group_shape == GroupShape.PER_TENSOR - else ( - "per_token" - if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape) - ) - ) + d = { + GroupShape.PER_TENSOR: "per_tensor", + GroupShape.PER_TOKEN: "per_token", + GroupShape.PER_CHANNEL: "per_channel", + } + group_shape = d.get(self.group_shape, str(self.group_shape)) return ( f"{fx.graph.dtype_abbrs[self.dtype]}," @@ -123,8 +122,8 @@ def __str__(self): kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) -kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) -kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) 
+kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL) +kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True) kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) From 7ead3787cfbf633c7f934a1e731c36236770c776 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 20 Jan 2026 16:51:29 +0000 Subject: [PATCH 2/4] add is_per_channel method Signed-off-by: tjtanaa --- vllm/model_executor/layers/quantization/utils/quant_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index f3879e40f49b..1c52b6712b85 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -56,6 +56,9 @@ def is_per_tensor(self) -> bool: def is_per_token(self) -> bool: return self.row == 1 and self.col == -1 + def is_per_channel(self) -> bool: + return self.row == -1 and self.col == 1 + def is_per_group(self) -> bool: return self.row == 1 and self.col >= 1 From 8783e9de289289df18e55e1f03ce105a3836daf4 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Thu, 22 Jan 2026 08:14:51 +0000 Subject: [PATCH 3/4] fix is_per_channel check condition and add documentation Signed-off-by: tjtanaa --- .../layers/quantization/utils/quant_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 68eda65f9a57..0d09caabfa19 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -57,15 +57,19 @@ def is_per_token(self) -> bool: return self.row == 1 and self.col == -1 def is_per_channel(self) -> bool: - return self.row == -1 and 
self.col == 1 + return self.row == 1 and self.col == -1 def is_per_group(self) -> bool: return self.row == 1 and self.col >= 1 GroupShape.PER_TENSOR = GroupShape(-1, -1) +# Input shape is in (M, K) +# Descriptor for activations that are quantized per token GroupShape.PER_TOKEN = GroupShape(1, -1) -GroupShape.PER_CHANNEL = GroupShape(-1, 1) +# Weight shape is in (N, K) +# Descriptor for weights that are quantized per output channel GroupShape.PER_CHANNEL = GroupShape(1, -1) @dataclass(frozen=True) @@ -127,9 +131,6 @@ def __str__(self): kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL) kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True) -kStaticChannelScale = ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL) -kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True) - kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) From 834a467dea91fb14104b127ce2612ca3fd0d687f Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Thu, 22 Jan 2026 08:16:39 +0000 Subject: [PATCH 4/4] add back quant type Signed-off-by: tjtanaa --- vllm/model_executor/layers/quantization/utils/quant_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 0d09caabfa19..22fb73bab503 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -128,6 +128,9 @@ def __str__(self): kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) +kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN) +kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True) + kStaticChannelScale = 
ScaleDesc(torch.float32, True, GroupShape.PER_CHANNEL) kFp8StaticChannelSym = QuantKey(FP8_DTYPE, kStaticChannelScale, symmetric=True)