Fix the `to` impl for the int4 weight-only use case (pytorch#522)
Summary:
Note that we can't do the following right now:
* initialize and quantize the model with int4_weight_only quant on CPU
* move the model to CUDA

We'll enable this in a separate PR.
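For concreteness, a minimal sketch of the two flows, assuming the quantize_ / int4_weight_only API from torchao.quantization that the new test below exercises (the ToyLinearModel here is a hypothetical stand-in for the test suite's toy model, and the exact failure point of the CPU flow is not specified by this commit):

```python
import torch
from torch import nn
from torchao.quantization import quantize_, int4_weight_only  # assumed import path

class ToyLinearModel(nn.Module):
    # Hypothetical stand-in for the toy model used in the test suite.
    def __init__(self, k=1024):
        super().__init__()
        self.linear1 = nn.Linear(k, k, bias=False)

    def forward(self, x):
        return self.linear1(x)

# Works today: build and quantize the model directly on CUDA, then call .to();
# this exercises the TensorCoreTiledAQTLayout.to() path the diff below fixes.
m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
quantize_(m, int4_weight_only())
m.to(device="cuda")

# Not supported yet: quantize on CPU, then migrate to CUDA (the flow listed
# above). Expected to fail until the follow-up PR lands; whether it errors
# during quantize_ or during .to() is not specified by this commit.
m_cpu = ToyLinearModel().eval().to(torch.bfloat16)
quantize_(m_cpu, int4_weight_only())
m_cpu.to("cuda")
```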

Test Plan:
CI
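(Locally, the new case can be run on a CUDA machine with, e.g., `pytest test/quantization/test_quant_api.py -k test_int4wo_quantized_model_to_device`.)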

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jul 17, 2024
1 parent e31b575 · commit 8bdfd0d
Showing 2 changed files with 19 additions and 2 deletions.
test/quantization/test_quant_api.py (18 additions, 1 deletion)
@@ -624,7 +624,7 @@ def test_quantized_tensor_subclass_save_load(self):

     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_model_to_device(self):
+    def test_int8wo_quantized_model_to_device(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
@@ -637,6 +637,23 @@ def test_quantized_model_to_device(self):
         cuda_res = m(*example_inputs_cuda)
         self.assertEqual(cuda_res.cpu(), ref)

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
+    def test_int4wo_quantized_model_to_device(self):
+        # TODO: change initial model to "cpu"
+        m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+
+        quantize_(m, int4_weight_only())
+        ref = m(*example_inputs)
+
+        example_inputs_cuda = (example_inputs[0].to("cuda"),)
+        m.to(device="cuda")
+        cuda_res = m(*example_inputs_cuda)
+        self.assertEqual(cuda_res.cpu(), ref)
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load_map_location(self):
torchao/dtypes/affine_quantized_tensor.py (1 addition, 1 deletion)
@@ -544,7 +544,7 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda"):
+        if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"):
             raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}")
         return self.__class__(
             self.packed_weight.to(device),
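Why flipping `or` to `and` fixes the int4 path: `_get_to_kwargs` normalizes the target to a `torch.device` (assuming it goes through torch's usual `_parse_to` handling), and a `torch.device` never compares equal to a plain string, so `device != "cuda"` is true even for CUDA targets. With `or`, the guard therefore raised for every device; with `and`, it raises only when the device type really isn't CUDA. A minimal check of the two predicates (no GPU required, since it only constructs device objects):

```python
import torch

device = torch.device("cuda")

# A torch.device never equals a plain string, so this is always True:
print(device != "cuda")  # True

# Old predicate: true for every torch.device, so .to() always raised.
old = device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda")
# Fixed predicate: true only when the target really is not a CUDA device.
new = device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda")
print(old, new)  # True False

# A genuinely non-CUDA target still trips the fixed guard:
cpu = torch.device("cpu")
print(cpu != "cuda" and (isinstance(cpu, torch.device) and cpu.type != "cuda"))  # True
```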
