
Commit 4077489

Enable a test for loading state_dict with tensor subclasses (pytorch#389)

Summary: Enables the use case of saving a quantized state dict produced by our new quantization API and then loading it into a non-quantized model; see the test for more details.

Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_save_load

Reviewers:

Subscribers:

Tasks:

Tags:

1 parent b21db88 commit 4077489
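For context, here is a minimal sketch of the end-to-end flow this commit enables, based on the test added below. It assumes the string-based `quantize` API shown in the diff; `ToyNet` is a hypothetical stand-in for any model with linear layers (the test itself uses `ToyLinearModel` and runs on CUDA).

```python
import copy
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize

class ToyNet(nn.Module):
    # hypothetical stand-in for a real model with linear layers
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 64, bias=False)

    def forward(self, x):
        return self.linear(x)

model = ToyNet().eval().to(torch.bfloat16)
model_fp = copy.deepcopy(model)  # keep a non-quantized copy around

# replace the linear weights with quantized tensor subclasses
model = quantize(model, "int8_weight_only")

# the quantized state dict round-trips through torch.save/torch.load
torch.save(model.state_dict(), "quantized_sd.pt")
state_dict = torch.load("quantized_sd.pt")

# ...and loads into the float copy; assign=True swaps the float
# parameters for the quantized subclass tensors instead of copying
model_fp.load_state_dict(state_dict, assign=True)
```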

File tree

3 files changed: +26 -6 lines changed

test/quantization/test_quant_api.py (+20)

@@ -617,5 +617,25 @@ def apply_my_dtype(weight):
         quantize(m, "my_dtype")
         m(*example_inputs)

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_save_load(self):
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16)
+
+        m = quantize(m, "int8_weight_only")
+        ref = m(*example_inputs)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        m_copy.load_state_dict(state_dict, assign=True)
+
+        res = m_copy(*example_inputs)
+        self.assertEqual(res, ref)
+
+
 if __name__ == "__main__":
     unittest.main()
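The `assign=True` argument is the load-bearing detail here: by default, `load_state_dict` copies values into the existing parameters in place, which cannot turn `m_copy`'s plain bfloat16 parameters into quantized tensor subclasses. A short sketch of the distinction, reusing the names from the test above:

```python
# default (assign=False): values are copy_()-ed into the existing
# bfloat16 Parameters, so the quantized subclass would not survive
# m_copy.load_state_dict(state_dict)

# assign=True: the Parameters are replaced outright with the tensors
# from the state dict, preserving the quantized tensor subclass
m_copy.load_state_dict(state_dict, assign=True)
```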

torchao/dtypes/aqt.py (+5 -5)

@@ -229,7 +229,7 @@ def from_float(
         )

     @property
-    def layout(self) -> str:
+    def extended_layout(self) -> str:
         return self.layout_tensor.extended_layout

     @classmethod
@@ -555,8 +555,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             is_cuda and
             input_is_int8 and
             input_tensor.dtype == weight_qtensor.dtype and
-            input_tensor.layout == "plain" and
-            weight_qtensor.layout == "plain"
+            input_tensor.extended_layout == "plain" and
+            weight_qtensor.extended_layout == "plain"
         ):
             #
             # 1. do the matrix form of dot(X_i, W_j)
@@ -598,7 +598,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_qtensor.dtype == torch.bfloat16 and
             len(weight_qtensor.shape) == 2 and
             weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-            weight_qtensor.layout == "tensor_core_tiled"
+            weight_qtensor.extended_layout == "tensor_core_tiled"
         ):
             assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
             assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -641,7 +641,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_qtensor.block_size[0] == 1 and
             weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
             weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-            weight_qtensor.layout == "plain"
+            weight_qtensor.extended_layout == "plain"
         ):
             # TODO: enable cpu and mps efficient path
             # per channel int8 weight only quantized mm
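The `layout` to `extended_layout` rename is more than cosmetic: `torch.Tensor` already has a native `layout` attribute (e.g. `torch.strided`), and a subclass property with the same name shadows it with a string, which is presumably what got in the way of save/load for these subclasses. A quick illustration of the built-in attribute:

```python
import torch

t = torch.randn(2, 2)
print(t.layout)        # torch.strided -- every Tensor carries this built-in
print(type(t.layout))  # <class 'torch.layout'>

# A subclass property named `layout` returning strings like "plain" would
# shadow this attribute, hence the rename to `extended_layout`.
```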

torchao/quantization/README.md (+1 -1)

@@ -171,7 +171,7 @@ from torchao.quantization import quant_api

 # for torch 2.4+
 from torchao.quantization.quant_api import quantize
-quantize(model, "int8_dynamic_quant")
+quantize(model, "int8_dynamic")

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
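For completeness, a hedged sketch of how the two README snippets might be dispatched on the installed torch version (the bare string comparison is illustrative only; real code should parse the version properly):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64)).eval()

if torch.__version__ >= "2.4":
    from torchao.quantization.quant_api import quantize
    model = quantize(model, "int8_dynamic")
else:
    # torch 2.2.2 and 2.3
    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
    change_linear_weights_to_int8_dqtensors(model)
```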
