diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 4d5a2c511c..c21f3a38be 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -23,7 +23,7 @@
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
-    quantize,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 # APIs to be deprecated (used for torch 2.2.2 and 2.3)
@@ -98,21 +98,21 @@ def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_weight_only(), set_inductor_config=False)
+        quantize_(mod, int8_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)
 
 
 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)
 
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4_weight_only(), set_inductor_config=False)
+        quantize_(mod, int4_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
@@ -127,8 +127,8 @@ def _int4wo_api(mod):
 def undo_recommended_configs():
     torch._inductor.config.coordinate_descent_tuning = False
     torch._inductor.config.coordinate_descent_check_all_directions = False
-    torch._inductor.config.force_fuse_int_mm_with_mul = False
-    torch._inductor.config.fx_graph_cache = False
+    torch._inductor.config.force_fuse_int_mm_with_mul = False
+    torch._inductor.config.fx_graph_cache = False
     torch._inductor.config.triton.unique_kernel_names = False
     torch.set_float32_matmul_precision("highest")
 
@@ -844,7 +844,7 @@ def api(mod):
                 kwargs_copy = kwargs.copy()
                 kwargs_copy["group_size"] = groupsize
                 del kwargs_copy["groupsize"]
-                quantize(mod, int4_weight_only(**kwargs_copy))
+                quantize_(mod, int4_weight_only(**kwargs_copy))
                 unwrap_tensor_subclass(mod)
             else:
                 change_linear_weights_to_int4_woqtensors(mod, **kwargs)
@@ -865,7 +865,7 @@ def test_dynamic_quant(self):
 
         m = nn.Sequential(nn.Linear(K, N))
         y_ref = m(x)
-        quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -1259,7 +1259,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-
+
 
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
diff --git a/test/prototype/test_quant_llm.py b/test/prototype/test_quant_llm.py
index 77eac6f69d..fab2d972b1 100644
--- a/test/prototype/test_quant_llm.py
+++ b/test/prototype/test_quant_llm.py
@@ -16,7 +16,7 @@
 )
 from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 
 
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -91,7 +91,7 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
 
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
         fpx_linear = copy.deepcopy(linear)
-        quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
+        quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
 
         x = torch.randn(N, IC, device=device, dtype=torch.half)
         expected = fpx_linear(x)
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e8b9d606d7..b137cd22dc 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -31,7 +31,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )
-from torchao import quantize
+from torchao import quantize_
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
             unwrap_tensor_subclass(model)
 
         api(m)
@@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self):
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
         group_size = 32
-        m = quantize(m, int4_weight_only(group_size=group_size))
+        quantize_(m, int4_weight_only(group_size=group_size))
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)
diff --git a/torchao/__init__.py b/torchao/__init__.py
index 3b5a1b3c0f..104dc5f311 100644
--- a/torchao/__init__.py
+++ b/torchao/__init__.py
@@ -30,14 +30,14 @@
 
 from torchao.quantization import (
     autoquant,
-    quantize,
+    quantize_,
 )
 from . import dtypes
 
 __all__ = [
     "dtypes",
     "autoquant",
-    "quantize",
+    "quantize_",
 ]
 
 # test-pytorchbot
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 76e7cd9ff2..4765d6a5fc 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -74,7 +74,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.dtypes import to_affine_quantized
 import copy
 from torchao.quantization.quant_api import (
-    quantize,
+    quantize_,
     int4_weight_only,
 )
 
@@ -101,7 +101,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
 group_size = 32
 # only works for torch 2.4+
-m = quantize(m, int4_weight_only(group_size=group_size))
+quantize_(m, int4_weight_only(group_size=group_size))
 
 # temporary workaround for tensor subclass + torch.compile
 from torchao.utils import unwrap_tensor_subclass
@@ -168,7 +168,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True
 
 # for torch 2.4+
 from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
-quantize(model, int8_dynamic_activation_int8_weight())
+quantize_(model, int8_dynamic_activation_int8_weight())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -180,7 +180,7 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int8_weight_only
-quantize(model, int8_weight_only())
+quantize_(model, int8_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -195,7 +195,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int4_weight_only
-quantize(model, int4_weight_only())
+quantize_(model, int4_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 115062c8f6..a1cf1bf034 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -29,7 +29,7 @@
     "quantize_affine",
     "dequantize_affine",
     "choose_qprams_affine",
-    "quantize",
+    "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 31ab71f385..3da530b940 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -54,7 +54,7 @@
     "Int4WeightOnlyQuantizer",
     "autoquant",
     "_get_subclass_inserter",
-    "quantize",
+    "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
@@ -259,8 +259,8 @@ def insert_subclass(lin):
     return insert_subclass
 
 
-def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module:
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`
+def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
 
     Args:
         model (torch.nn.Module): input model
@@ -273,7 +273,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
 
         import torch
         import torch.nn as nn
-        from torchao import quantize
+        from torchao import quantize_
 
         # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
         # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
@@ -286,7 +286,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
         from torchao.quantization.quant_api import int4_weight_only
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        m = quantize(m, int4_weight_only(group_size=32))
+        quantize_(m, int4_weight_only(group_size=32))
 
         # 2. write your own new apply_tensor_subclass
         # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
@@ -305,7 +305,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
             return isinstance(module, nn.Linear)
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        m = quantize(m, apply_weight_quant, filter_fn)
+        quantize_(m, apply_weight_quant, filter_fn)
 
     """
     if set_inductor_config:
@@ -315,7 +315,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         _get_linear_subclass_inserter(apply_tensor_subclass),
         _is_linear if filter_fn is None else filter_fn,
     )
-    return model
+
 
 
 def int8_dynamic_activation_int4_weight(group_size=32):
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 07e0118d20..a082cfe53a 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -19,9 +19,9 @@
 # for APIs for earlier torch version and other quantization techniques
 
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
-quantize(model, int8_dynamic_activation_int8_weight())
+quantize_(model, int8_dynamic_activation_int8_weight())
 ## Quantization code - end
 
 ## compilation configs
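Below is a minimal usage sketch of the renamed API after this change. It is not part of the diff; the toy two-layer model, its shapes, and the CPU int8 weight-only choice are illustrative assumptions. The point it shows is that `quantize_` mutates the module in place and returns nothing, whereas the old `quantize` returned the model.

```python
import torch
import torch.nn as nn

# renamed in this diff: `quantize_` replaces the old `quantize`
from torchao.quantization import quantize_, int8_weight_only
# temporary workaround for tensor subclass + torch.compile (per the README in this diff)
from torchao.utils import unwrap_tensor_subclass

# illustrative toy model; any nn.Module containing nn.Linear layers works
m = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32)).eval()

# old API (pre-diff):  m = quantize(m, int8_weight_only())
# new API (this diff): modifies `m` in place, no reassignment needed
quantize_(m, int8_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(m)

out = m(torch.randn(8, 32))
```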