From 47bb9ca7fd6a13463bba858299318407a5f3e954 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 17 Jun 2024 11:33:22 -0700 Subject: [PATCH] Enable a test for loading state_dict with tensor subclasses Summary: Enabling the use case ofsaving a quantized state dict from our new quantization API and then load that to a non-quantized model, see test for more details Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_save_load Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_api.py | 20 ++++++++++++++++++++ torchao/dtypes/aqt.py | 10 +++++----- torchao/quantization/README.md | 2 +- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 36d55400bc..8a78f4a73e 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -617,5 +617,25 @@ def apply_my_dtype(weight): quantize(m, "my_dtype") m(*example_inputs) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_quantized_tensor_subclass_save_load(self): + m = ToyLinearModel().eval().to(torch.bfloat16) + m_copy = copy.deepcopy(m) + example_inputs = m.example_inputs(dtype=torch.bfloat16) + + m = quantize(m, "int8_weight_only") + ref = m(*example_inputs) + with tempfile.NamedTemporaryFile() as f: + torch.save(m.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + + m_copy.load_state_dict(state_dict, assign=True) + + res = m_copy(*example_inputs) + self.assertEqual(res, ref) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index 4d83c92050..83c7d22fb4 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -229,7 +229,7 @@ def from_float( ) @property - def layout(self) -> str: + def extended_layout(self) -> str: return self.layout_tensor.extended_layout @classmethod @@ -555,8 +555,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): is_cuda and input_is_int8 and input_tensor.dtype == weight_qtensor.dtype and - input_tensor.layout == "plain" and - weight_qtensor.layout == "plain" + input_tensor.extended_layout == "plain" and + weight_qtensor.extended_layout == "plain" ): # # 1. do the matrix form of dot(X_i, W_j) @@ -598,7 +598,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): weight_qtensor.dtype == torch.bfloat16 and len(weight_qtensor.shape) == 2 and weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and - weight_qtensor.layout == "tensor_core_tiled" + weight_qtensor.extended_layout == "tensor_core_tiled" ): assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}" assert input_tensor.shape[-1] == weight_qtensor.shape[1], ( @@ -641,7 +641,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): weight_qtensor.block_size[0] == 1 and weight_qtensor.block_size[1] == weight_qtensor.shape[1] and weight_qtensor.zero_point_domain == ZeroPointDomain.INT and - weight_qtensor.layout == "plain" + weight_qtensor.extended_layout == "plain" ): # TODO: enable cpu and mps efficient path # per channel int8 weight only quantizated mm diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 1a804fe806..28f4ea71c5 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -171,7 +171,7 @@ from torchao.quantization import quant_api # for torch 2.4+ from torchao.quantization.quant_api import quantize -quantize(model, "int8_dynamic_quant") +quantize(model, "int8_dynamic") # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors