From 5013a2761eeb048479215219b291e1ecb51e6d49 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 15 Aug 2025 14:54:31 -0700
Subject: [PATCH] Remove group_size arg in Float8DynamicActivationInt4WeightConfig

Summary:
Fixes: https://github.com/pytorch/ao/issues/2763

Test Plan:
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .../workflows/int4/test_int4_preshuffled_tensor.py | 2 +-
 test/quantization/test_qat.py                      | 2 +-
 torchao/quantization/quant_api.py                  | 6 +++---
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
index a03970169e..67f8416050 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
@@ -33,8 +33,8 @@
     version=2,
 )
 
+# only group_size 128 is supported
 FP8_ACT_CONFIG = Float8DynamicActivationInt4WeightConfig(
-    group_size=128,
     packing_format="preshuffled",
 )
 
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 4035e273c1..4c03442ad7 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -1927,7 +1927,7 @@ def test_quantize_api_fp8_int4(self):
         quantize_(model, QATConfig(Float8DynamicActivationInt4WeightConfig(), step="convert"))
         """
         self._test_quantize_api_against_ptq(
-            Float8DynamicActivationInt4WeightConfig(group_size=128),
+            Float8DynamicActivationInt4WeightConfig(),
             target_prepare_sqnr=15,
             target_convert_sqnr=float("inf"),
         )
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 5d191a7c0e..ac5177a058 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1149,13 +1149,13 @@ def _int4_weight_only_transform(
 class Float8DynamicActivationInt4WeightConfig(AOBaseConfig):
     """Configuration for apply float8 dynamic per row quantization and int4
     per group weight quantization to linear
+    (only group_size 128 is supported right now, since the underlying kernel only
+    supports group sizes of 128 and above and there is no benefit to making it larger)
 
     Args:
-        `group_size`: group size for groupwise quantization for weight
         `packing_format`: how the weight is packed, only preshuffled is supported
     """
 
-    group_size: int = 128
     packing_format: PackingFormat = "preshuffled"
 
 
@@ -1167,13 +1167,13 @@ def _float8_dynamic_activation_int4_weight_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
-    group_size = config.group_size
     packing_format = config.packing_format
 
     assert packing_format == "preshuffled", (
         f"only preshuffled packing_format supported right now, got: {packing_format}"
     )
     weight = module.weight
+    group_size = 128  # the only group size the preshuffled kernel supports
     block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
     new_weight = Int4PreshuffledTensor.from_hp(
         module.weight,
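
Usage sketch: after this change, callers construct the config with no
group_size argument (128 is fixed internally). A minimal example, assuming
Float8DynamicActivationInt4WeightConfig is re-exported from
torchao.quantization alongside quantize_ (otherwise import it from
torchao.quantization.quant_api), a CUDA device with the preshuffled kernels
available, and a weight inner dim divisible by 128; the toy model is
illustrative:

    import torch
    from torchao.quantization import quantize_, Float8DynamicActivationInt4WeightConfig

    # in_features=256 is divisible by the fixed group size of 128
    model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(torch.bfloat16).cuda()

    # group_size is no longer a constructor argument; the transform uses 128
    quantize_(model, Float8DynamicActivationInt4WeightConfig(packing_format="preshuffled"))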