Enable a test for loading state_dict with tensor subclasses
Summary:
Enabling the use case of saving a quantized state dict from our new quantization API and then loading it into a non-quantized model; see the test for more details, and the usage sketch after the changed-files summary below.

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_save_load

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 17, 2024
1 parent eb1511e commit 47bb9ca
Showing 3 changed files with 26 additions and 6 deletions.
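For orientation, here is a minimal end-to-end sketch of the workflow this commit enables, modeled on the new test below. ToyModel and the tensor shapes are made-up stand-ins, and the same environment assumptions as the test apply (torch 2.4+, CUDA available):

import copy
import tempfile

import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize

class ToyModel(nn.Module):
    # Stand-in module; any nn.Module with linear layers works the same way.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32, bias=False)

    def forward(self, x):
        return self.linear(x)

m = ToyModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)            # stays a plain, non-quantized model
x = torch.randn(8, 64, dtype=torch.bfloat16)

m = quantize(m, "int8_weight_only")  # weights become quantized tensor subclasses
ref = m(x)

# Round-trip the quantized state_dict through torch.save / torch.load.
with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

# assign=True rebinds the float parameters to the loaded quantized tensors
# instead of copying values into them, so the subclass survives the load.
m_copy.load_state_dict(state_dict, assign=True)
assert torch.equal(m_copy(x), ref)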
20 changes: 20 additions & 0 deletions test/quantization/test_quant_api.py
@@ -617,5 +617,25 @@ def apply_my_dtype(weight):
         quantize(m, "my_dtype")
         m(*example_inputs)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_save_load(self):
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16)
+
+        m = quantize(m, "int8_weight_only")
+        ref = m(*example_inputs)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        m_copy.load_state_dict(state_dict, assign=True)
+
+        res = m_copy(*example_inputs)
+        self.assertEqual(res, ref)
+
+
 if __name__ == "__main__":
     unittest.main()
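The assign=True argument is the load-bearing detail here: the default assign=False copies loaded values into the existing float parameters in place, which would discard the tensor subclass, whereas assign=True replaces the parameters with the loaded tensors. A small sketch of the difference, using a dtype change as a stand-in for the subclass case (illustrative only, not part of the commit):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4, bias=False)                 # float32 weight
loaded = torch.zeros(4, 4, dtype=torch.bfloat16)  # stand-in for a loaded tensor

# Default (assign=False): values are copied into the existing parameter,
# which keeps its original properties -- float32 here.
lin.load_state_dict({"weight": loaded})
print(lin.weight.dtype)   # torch.float32

# assign=True: the parameter is replaced by the loaded tensor, so the
# loaded tensor's properties (class, dtype, layout) are preserved.
lin.load_state_dict({"weight": loaded}, assign=True)
print(lin.weight.dtype)   # torch.bfloat16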
10 changes: 5 additions & 5 deletions torchao/dtypes/aqt.py
@@ -229,7 +229,7 @@ def from_float(
         )
 
     @property
-    def layout(self) -> str:
+    def extended_layout(self) -> str:
         return self.layout_tensor.extended_layout
 
     @classmethod
@@ -555,8 +555,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         is_cuda and
         input_is_int8 and
         input_tensor.dtype == weight_qtensor.dtype and
-        input_tensor.layout == "plain" and
-        weight_qtensor.layout == "plain"
+        input_tensor.extended_layout == "plain" and
+        weight_qtensor.extended_layout == "plain"
     ):
         #
         # 1. do the matrix form of dot(X_i, W_j)
@@ -598,7 +598,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.dtype == torch.bfloat16 and
         len(weight_qtensor.shape) == 2 and
         weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-        weight_qtensor.layout == "tensor_core_tiled"
+        weight_qtensor.extended_layout == "tensor_core_tiled"
     ):
         assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
         assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -641,7 +641,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         weight_qtensor.block_size[0] == 1 and
         weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
         weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-        weight_qtensor.layout == "plain"
+        weight_qtensor.extended_layout == "plain"
     ):
         # TODO: enable cpu and mps efficient path
         # per channel int8 weight only quantizated mm
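A note on the rename in this file: torch.Tensor already exposes a built-in layout attribute (torch.strided, torch.sparse_coo, ...), so a subclass property named layout that returns a string shadows it, which presumably breaks paths such as serialization that expect the built-in value. A hypothetical minimal subclass showing the two attributes coexisting under distinct names (illustration only, not torchao's AffineQuantizedTensor):

import torch

class PackedTensor(torch.Tensor):
    # Keep the packing-format string under its own name so that
    # torch.Tensor.layout stays meaningful to the rest of PyTorch.
    @property
    def extended_layout(self) -> str:
        return "plain"

t = torch.zeros(2, 2).as_subclass(PackedTensor)
print(t.layout)           # torch.strided  (the built-in attribute)
print(t.extended_layout)  # "plain"        (the subclass's packing format)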
2 changes: 1 addition & 1 deletion torchao/quantization/README.md
@@ -171,7 +171,7 @@ from torchao.quantization import quant_api
 
 # for torch 2.4+
 from torchao.quantization.quant_api import quantize
-quantize(model, "int8_dynamic_quant")
+quantize(model, "int8_dynamic")
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
