@@ -43,8 +43,7 @@ def get_registered_ref_implementations() -> set[str]:
 }
 
 
-@impl_tracked(m, "quantize_per_tensor")
-def quantize_per_tensor(
+def quantize_per_tensor_common(
     input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
@@ -93,8 +92,68 @@ def quantize_per_tensor(
     )
 
 
-@impl_tracked(m, "dequantize_per_tensor")
-def dequantize_per_tensor(
+def quantize_per_tensor_variant(
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantize_per_tensor variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            scale: float,
+            zero_point: int,
+            quant_min: int,
+            quant_max: int,
+            out_dtype: torch.dtype,
+        ) -> torch.Tensor:
+            if dtype and out_dtype != dtype:
+                raise ValueError(f"dtype must be {dtype}. Got {out_dtype}")
+
+            return quantize_per_tensor_common(
+                input_tensor,
+                scale,
+                zero_point,
+                quant_min,
+                quant_max,
+                out_dtype,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl_tracked(m, "quantize_per_tensor")
+@quantize_per_tensor_variant()
+def quantize_per_tensor() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym8u")
+@quantize_per_tensor_variant(torch.uint8)
+def quantize_per_tensor_asym8u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym8s")
+@quantize_per_tensor_variant(torch.int8)
+def quantize_per_tensor_asym8s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym16u")
+@quantize_per_tensor_variant(torch.uint16)
+def quantize_per_tensor_asym16u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym16s")
+@quantize_per_tensor_variant(torch.int16)
+def quantize_per_tensor_asym16s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym32s")
+@quantize_per_tensor_variant(torch.int32)
+def quantize_per_tensor_asym32s() -> torch.Tensor: ...
+
+
+def dequantize_per_tensor_common(
     input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
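
Note on the hunk above: as a quick sanity check of the factory pattern, here is a sketch of how the dtype-pinned variants behave once both decorators are applied. It is not part of the diff; it assumes impl_tracked returns the decorated function unchanged and that the module-level names above are importable.

import torch

x = torch.tensor([0.0, 0.5, 1.0])

# Matching dtype: the check passes and the call reaches quantize_per_tensor_common.
q = quantize_per_tensor_asym8s(x, 1.0 / 127, 0, -128, 127, torch.int8)

# Mismatched dtype: the check injected by quantize_per_tensor_variant raises.
try:
    quantize_per_tensor_asym8s(x, 1.0 / 127, 0, -128, 127, torch.uint8)
except ValueError as e:
    print(e)  # dtype must be torch.int8. Got torch.uint8

Each stub body is just `...`: the factory discards it and returns a fresh `variant` closure, so the stub exists only to carry the registered name.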
@@ -133,14 +192,72 @@ def dequantize_per_tensor(
     if input_tensor.dtype != dtype:
         raise ValueError("Input dtype must match dtype")
 
-    # Use the reference implementation from torch quantized_decomposed library
-    # Unlike quantize_per_tensor, dequantize_per_tensor doesn't have a behavior
-    # difference, since there's no rounding algorithm (just arithmetic).
     return torch.ops.quantized_decomposed.dequantize_per_tensor(
         input_tensor, scale, zero_point, quant_min, quant_max, dtype
     )
 
 
+def dequantize_per_tensor_variant(
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a dequantize_per_tensor variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            scale: float,
+            zero_point: int,
+            quant_min: int,
+            quant_max: int,
+            in_dtype: torch.dtype,
+        ) -> torch.Tensor:
+            if dtype and in_dtype != dtype:
+                raise ValueError(f"dtype must be {dtype}. Got {in_dtype}")
+
+            return dequantize_per_tensor_common(
+                input_tensor,
+                scale,
+                zero_point,
+                quant_min,
+                quant_max,
+                in_dtype,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl_tracked(m, "dequantize_per_tensor")
+@dequantize_per_tensor_variant()
+def dequantize_per_tensor() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym8u")
+@dequantize_per_tensor_variant(torch.uint8)
+def dequantize_per_tensor_asym8u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym32s")
+@dequantize_per_tensor_variant(torch.int32)
+def dequantize_per_tensor_asym32s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym16u")
+@dequantize_per_tensor_variant(torch.uint16)
+def dequantize_per_tensor_asym16u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym8s")
+@dequantize_per_tensor_variant(torch.int8)
+def dequantize_per_tensor_asym8s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym16s")
+@dequantize_per_tensor_variant(torch.int16)
+def dequantize_per_tensor_asym16s() -> torch.Tensor: ...
+
+
 @impl_tracked(m, "quantized_add.per_tensor")
 def quantized_add_per_tensor(
     X: torch.Tensor,
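
Note on the dequantize hunk: the factory mirrors the quantize one exactly, so a round trip through a pinned pair is a cheap behavioral check. Again a sketch under the same assumptions; dequantize_per_tensor_common delegates to torch.ops.quantized_decomposed.dequantize_per_tensor, whose registration normally comes from importing torch.ao.quantization.fx._decomposed.

import torch
import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed ops

x = torch.tensor([0.25, 0.5])
scale, zero_point = 1.0 / 127, 0

q = quantize_per_tensor_asym8s(x, scale, zero_point, -128, 127, torch.int8)
y = dequantize_per_tensor_asym8s(q, scale, zero_point, -128, 127, torch.int8)

# Whatever rounding quantize_per_tensor_common uses, the reconstruction
# error of a round trip is bounded by one quantization step.
assert torch.allclose(x, y, atol=scale)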