Skip to content

Commit a9907f1

Browse files
authored
Accept device for Int8DynActInt4WeightQuantizer (pytorch#475)
Before we migrate away from the `Quantizer` APIs, we want to align the `__init__` arguments between `Int8DynActInt4WeightQuantizer` and `Int4WeightOnlyQuantizer` so the two quantizers are easier for users to work with interchangeably.
1 parent 32a6503 commit a9907f1

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchao/quantization/GPTQ.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -942,12 +942,14 @@ def __init__(
942942
padding_allowed: bool = False,
943943
precision: torch.dtype = torch.float32,
944944
scales_precision: torch.dtype = torch.float32,
945+
device: torch.device = torch.device("cpu"),
945946
) -> None:
946947
super().__init__()
947948
self.groupsize: int = groupsize
948949
self.padding_allowed: bool = padding_allowed
949950
self.precision: torch.dtype = precision
950951
self.scales_precision: torch.dtype = scales_precision
952+
self.device: torch.device = device
951953

952954
@torch.no_grad()
953955
def _create_quantized_state_dict(
@@ -988,9 +990,9 @@ def _create_quantized_state_dict(
988990
self.groupsize,
989991
self.scales_precision,
990992
)
991-
cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
992-
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
993-
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
993+
cur_state_dict[f"{fqn}.weight"] = weight_int8.to(self.device)
994+
cur_state_dict[f"{fqn}.scales"] = scales.to(self.device)
995+
cur_state_dict[f"{fqn}.zeros"] = zeros.to(self.device)
994996
# TODO: support bias?
995997

996998
return cur_state_dict

0 commit comments

Comments (0)