Skip to content

Commit 49a7a89

Browse files
committed
update int8 quant api
1 parent 7006cae commit 49a7a89

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def from_hp(
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize int8 tensor to floating point"""
 
+        if output_dtype is None:
+            output_dtype = self.dtype
+
         qdata_fp = self.qdata.to(output_dtype)
         # Reshape scale to broadcast if granularity is block-wise
         scale_expanded = _maybe_expand_scale_to_tensor_shape(
@@ -153,12 +156,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
 implements_torch_function = Int8Tensor.implements_torch_function
 
 
-@implements([aten.dequantize.self])
-def _(func, types, args, kwargs):
-    """dequantization: int8 -> float"""
-    return args[0].dequantize()
-
-
 @implements(aten.linear.default)
 @implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):

0 commit comments

Comments
 (0)