diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 1120202f29fd..34d403a21783 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -29,8 +29,8 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     kFp8DynamicTokenSym,
+    kFp8StaticChannelSym,
     kFp8StaticTensorSym,
-    kFp8StaticTokenSym,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     cutlass_block_fp8_supported,
@@ -56,7 +56,7 @@
     DYNAMIC_QUANT: kFp8DynamicTokenSym,
 }
 weight_quant_key_mapping = {
-    QuantizationStrategy.CHANNEL: kFp8StaticTokenSym,
+    QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
     QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
 }
 logger = init_logger(__name__)
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 45d2e4e33819..ded1780ed5f6 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -28,7 +28,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped,
     kFp8DynamicTokenSym,
-    kFp8StaticTokenSym,
+    kFp8StaticChannelSym,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
@@ -96,7 +96,7 @@ def __init__(self, quant_config: FBGEMMFp8Config):
         self.out_dtype = torch.get_default_dtype()
         self.fp8_linear = init_fp8_linear_kernel(
             activation_quant_key=kFp8DynamicTokenSym,
-            weight_quant_key=kFp8StaticTokenSym,
+            weight_quant_key=kFp8StaticChannelSym,
             out_dtype=torch.get_default_dtype(),
             module_name=self.__class__.__name__,
         )
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 0de9cb88da1e..497382dc282b 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -81,8 +81,8 @@
     cutlass_fp4_supported,
     is_layer_skipped,
     kFp8DynamicTokenSym,
+    kFp8StaticChannelSym,
     kFp8StaticTensorSym,
-    kFp8StaticTokenSym,
     kNvfp4Dynamic,
     kNvfp4Static,
     swizzle_blockscale,
@@ -532,7 +532,7 @@ def __init__(self, quant_config: ModelOptFp8Config) -> None:
         self.quant_config = quant_config
         self.fp8_linear = init_fp8_linear_kernel(
             activation_quant_key=kFp8DynamicTokenSym,
-            weight_quant_key=kFp8StaticTokenSym,
+            weight_quant_key=kFp8StaticChannelSym,
             out_dtype=torch.get_default_dtype(),
             module_name=self.__class__.__name__,
         )
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index 635b5cf894ef..36255b98591e 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -15,8 +15,8 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     kFp8DynamicTokenSym,
+    kFp8StaticChannelSym,
     kFp8StaticTensorSym,
-    kFp8StaticTokenSym,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
@@ -54,7 +54,7 @@ def __init__(
             kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
         )
         self.weight_quant_key = (
-            kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
+            kFp8StaticChannelSym if per_token_weight else kFp8StaticTensorSym
         )
         self.out_dtype = torch.get_default_dtype()
 
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index bc7458444412..22fb73bab503 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -57,15 +57,19 @@ def is_per_token(self) -> bool:
         return self.row == 1 and self.col == -1
 
     def is_per_channel(self) -> bool:
-        return self.row == -1 and self.col == 1
+        return self.row == 1 and self.col == -1
 
     def is_per_group(self) -> bool:
         return self.row == 1 and self.col >= 1
 
 
 GroupShape.PER_TENSOR = GroupShape(-1, -1)
+# Input shape is in (M, K)
+# Descriptor for weights that are quantized per token
 GroupShape.PER_TOKEN = GroupShape(1, -1)
-GroupShape.PER_CHANNEL = GroupShape(-1, 1)
+# Weight shape is in (N, K)
+# Descriptor for weights that are quantized per output channel
+GroupShape.PER_CHANNEL = GroupShape(1, -1)
 
 
 @dataclass(frozen=True)