
Commit 6569ebd

Andrew Grebenisan authored and facebook-github-bot committed
All type-specific quantize/dequantize (#15165)
Summary: As titled.

Reviewed By: skrtskrtfb

Differential Revision: D84675269
1 parent 06ea3d6 commit 6569ebd

File tree

2 files changed: +124 -17 lines


backends/cadence/aot/ops_registrations.py

Lines changed: 0 additions & 10 deletions
@@ -58,20 +58,10 @@ def _validate_ref_impl_exists() -> None:
         "cadence::_softmax_f32_f32",
         "cadence::requantize",  # We should only support per_tensor variant, should remove
         "cadence::quantized_softmax.per_tensor",
-        "cadence::quantize_per_tensor_asym8u",
-        "cadence::quantize_per_tensor_asym8s",
-        "cadence::dequantize_per_tensor_asym8u",
-        "cadence::dequantize_per_tensor_asym32s",
-        "cadence::dequantize_per_tensor_asym16u",
         "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
-        "cadence::quantize_per_tensor_asym32s",
         "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
         "cadence::linalg_svd",
         "cadence::quantized_conv2d_nhwc",  # We should only support per_tensor variant, should remove
-        "cadence::quantize_per_tensor_asym16u",
-        "cadence::dequantize_per_tensor_asym8s",
-        "cadence::quantize_per_tensor_asym16s",
-        "cadence::dequantize_per_tensor_asym16s",
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
         "cadence::quantized_layer_norm",  # We should only support per_tensor variant, should remove

backends/cadence/aot/ref_implementations.py

Lines changed: 124 additions & 7 deletions
@@ -43,8 +43,7 @@ def get_registered_ref_implementations() -> set[str]:
     }


-@impl_tracked(m, "quantize_per_tensor")
-def quantize_per_tensor(
+def quantize_per_tensor_common(
     input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
@@ -93,8 +92,68 @@ def quantize_per_tensor(
     )


-@impl_tracked(m, "dequantize_per_tensor")
-def dequantize_per_tensor(
+def quantize_per_tensor_variant(
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantize_per_tensor variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            scale: float,
+            zero_point: int,
+            quant_min: int,
+            quant_max: int,
+            out_dtype: torch.dtype,
+        ) -> torch.Tensor:
+            if dtype and out_dtype != dtype:
+                raise ValueError(f"dtype must be {dtype}. Got {out_dtype}")
+
+            return quantize_per_tensor_common(
+                input_tensor,
+                scale,
+                zero_point,
+                quant_min,
+                quant_max,
+                out_dtype,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl_tracked(m, "quantize_per_tensor")
+@quantize_per_tensor_variant()
+def quantize_per_tensor() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym8u")
+@quantize_per_tensor_variant(torch.uint8)
+def quantize_per_tensor_asym8u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym8s")
+@quantize_per_tensor_variant(torch.int8)
+def quantize_per_tensor_asym8s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym16u")
+@quantize_per_tensor_variant(torch.uint16)
+def quantize_per_tensor_asym16u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym16s")
+@quantize_per_tensor_variant(torch.int16)
+def quantize_per_tensor_asym16s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym32s")
+@quantize_per_tensor_variant(torch.int32)
+def quantize_per_tensor_asym32s() -> torch.Tensor: ...
+
+
+def dequantize_per_tensor_common(
     input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
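
The addition above is a decorator factory: quantize_per_tensor_variant(dtype) builds a wrapper that rejects any out_dtype other than the one it was pinned to (or accepts all dtypes when called with no argument), then forwards to quantize_per_tensor_common; the decorated stub bodies are discarded. A minimal standalone sketch of the same pattern, with an illustrative round-and-clamp standing in for the Cadence rounding behavior:

# Standalone sketch of the decorator-factory pattern (illustrative names,
# simplified quantization; not the Cadence implementation).
from typing import Callable

import torch

def make_quant_variant(pinned: torch.dtype | None = None) -> Callable[..., torch.Tensor]:
    # Returns a wrapper that rejects any out_dtype other than the pinned one.
    def variant(
        x: torch.Tensor,
        scale: float,
        zero_point: int,
        quant_min: int,
        quant_max: int,
        out_dtype: torch.dtype,
    ) -> torch.Tensor:
        if pinned and out_dtype != pinned:
            raise ValueError(f"dtype must be {pinned}. Got {out_dtype}")
        # Simplified round-and-clamp quantization, for illustration only.
        return torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max).to(out_dtype)
    return variant

asym8s = make_quant_variant(torch.int8)
asym8s(torch.tensor([0.5]), 1 / 128, 0, -128, 127, torch.int8)   # tensor([64], dtype=torch.int8)
asym8s(torch.tensor([0.5]), 1 / 128, 0, -128, 127, torch.uint8)  # raises ValueError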
@@ -133,14 +192,72 @@ def dequantize_per_tensor(
     if input_tensor.dtype != dtype:
         raise ValueError("Input dtype must match dtype")

-    # Use the reference implementation from torch quantized_decomposed library
-    # Unlike quantize_per_tensor, dequantize_per_tensor doesn't have a behavior
-    # difference, since there's no rounding algorithm (just arithmetic).
     return torch.ops.quantized_decomposed.dequantize_per_tensor(
         input_tensor, scale, zero_point, quant_min, quant_max, dtype
     )


+def dequantize_per_tensor_variant(
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a dequantize_per_tensor variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            scale: float,
+            zero_point: int,
+            quant_min: int,
+            quant_max: int,
+            in_dtype: torch.dtype,
+        ) -> torch.Tensor:
+            if dtype and in_dtype != dtype:
+                raise ValueError(f"dtype must be {dtype}. Got {in_dtype}")
+
+            return dequantize_per_tensor_common(
+                input_tensor,
+                scale,
+                zero_point,
+                quant_min,
+                quant_max,
+                in_dtype,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl_tracked(m, "dequantize_per_tensor")
+@dequantize_per_tensor_variant()
+def dequantize_per_tensor() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym8u")
+@dequantize_per_tensor_variant(torch.uint8)
+def dequantize_per_tensor_asym8u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym32s")
+@dequantize_per_tensor_variant(torch.int32)
+def dequantize_per_tensor_asym32s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym16u")
+@dequantize_per_tensor_variant(torch.uint16)
+def dequantize_per_tensor_asym16u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym8s")
+@dequantize_per_tensor_variant(torch.int8)
+def dequantize_per_tensor_asym8s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym16s")
+@dequantize_per_tensor_variant(torch.int16)
+def dequantize_per_tensor_asym16s() -> torch.Tensor: ...
+
+
 @impl_tracked(m, "quantized_add.per_tensor")
 def quantized_add_per_tensor(
     X: torch.Tensor,
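
The dequantize factory mirrors the quantize one, checking the input dtype instead of the output dtype. Since dequantization is plain arithmetic, (x - zero_point) * scale, with no rounding step, a pinned variant reduces to a dtype check plus that formula. A standalone sketch of the roundtrip for the asym8s case (illustrative code; the real variants forward to dequantize_per_tensor_common, which delegates to torch.ops.quantized_decomposed.dequantize_per_tensor):

# Standalone sketch, illustrative only; not the Cadence implementation.
import torch

def dequant_asym8s(
    x: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Pinned variant: only torch.int8 inputs are accepted.
    if dtype != torch.int8:
        raise ValueError(f"dtype must be {torch.int8}. Got {dtype}")
    # Dequantize is pure arithmetic; no rounding is involved.
    return (x.to(torch.float32) - zero_point) * scale

q = torch.tensor([64], dtype=torch.int8)
dequant_asym8s(q, 1 / 128, 0, -128, 127, torch.int8)  # tensor([0.5000]), recovering the value quantized above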
