diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index a0d2c2d0e6..ab24fc981c 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -624,7 +624,7 @@ def test_quantized_tensor_subclass_save_load(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_model_to_device(self):
+    def test_int8wo_quantized_model_to_device(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
@@ -637,6 +637,23 @@ def test_quantized_model_to_device(self):
         cuda_res = m(*example_inputs_cuda)
         self.assertEqual(cuda_res.cpu(), ref)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
+    def test_int4wo_quantized_model_to_device(self):
+        # TODO: change initial model to "cpu"
+        m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+
+        quantize_(m, int4_weight_only())
+        ref = m(*example_inputs)
+
+        example_inputs_cuda = (example_inputs[0].to("cuda"),)
+        m.to(device="cuda")
+        cuda_res = m(*example_inputs_cuda)
+        self.assertEqual(cuda_res.cpu(), ref)
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load_map_location(self):
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 513ba2a8ca..da5cc7d28b 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -544,7 +544,7 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda"):
+        if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"):
             raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}")
         return self.__class__(
             self.packed_weight.to(device),
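
For reviewers, a minimal standalone sketch of why the `or` → `and` change in the `TensorCoreTiledAQTLayout.to()` guard matters. `_get_to_kwargs` appears to normalize the target to a `torch.device`, and PyTorch does not compare a `torch.device` equal to a plain string, so `torch.device("cuda") != "cuda"` is `True`; under the old `or`, every `torch.device` argument therefore raised, CUDA included. The `should_raise_*` helpers below are illustrative names, not torchao code:

```python
import torch

def should_raise_old(device):
    # Before the fix: the string comparison is True for any
    # torch.device, so the `or` raises even for CUDA targets.
    return device != "cuda" or (
        isinstance(device, torch.device) and device.type != "cuda"
    )

def should_raise_new(device):
    # After the fix: both clauses must flag a non-CUDA target.
    return device != "cuda" and (
        isinstance(device, torch.device) and device.type != "cuda"
    )

for d in ("cuda", torch.device("cuda"), torch.device("cpu")):
    print(d, should_raise_old(d), should_raise_new(d))
# expected:
# cuda False False
# cuda True  False   <- the spurious error this PR fixes
# cpu  True  True
```

Note that with `and`, a bare device string such as `"cpu"` still slips past the guard (the `isinstance` clause is `False` for strings), so the check only rejects non-CUDA `torch.device` objects.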