@@ -43,8 +43,7 @@ def get_registered_ref_implementations() -> set[str]:
4343}
4444
4545
46- @impl_tracked (m , "quantize_per_tensor" ) 
47- def  quantize_per_tensor (
46+ def  quantize_per_tensor_common (
4847    input_tensor : torch .Tensor ,
4948    scale : float ,
5049    zero_point : int ,
@@ -93,8 +92,68 @@ def quantize_per_tensor(
9392    )
9493
9594
96- @impl_tracked (m , "dequantize_per_tensor" ) 
97- def  dequantize_per_tensor (
95+ def  quantize_per_tensor_variant (
96+     dtype : torch .dtype  |  None  =  None ,
97+ ) ->  Callable [[Callable [..., torch .Tensor ]], Callable [..., torch .Tensor ]]:
98+     """Create a quantize_per_tensor variant with type checking.""" 
99+ 
100+     def  decorator (_ : Callable [..., torch .Tensor ]) ->  Callable [..., torch .Tensor ]:
101+         def  variant (
102+             input_tensor : torch .Tensor ,
103+             scale : float ,
104+             zero_point : int ,
105+             quant_min : int ,
106+             quant_max : int ,
107+             out_dtype : torch .dtype ,
108+         ) ->  torch .Tensor :
109+             if  dtype  and  out_dtype  !=  dtype :
110+                 raise  ValueError (f"dtype must be { dtype } { out_dtype }  )
111+ 
112+             return  quantize_per_tensor_common (
113+                 input_tensor ,
114+                 scale ,
115+                 zero_point ,
116+                 quant_min ,
117+                 quant_max ,
118+                 out_dtype ,
119+             )
120+ 
121+         return  variant 
122+ 
123+     return  decorator 
124+ 
125+ 
# Registered quantize_per_tensor reference implementations. Each stub body
# is replaced by quantize_per_tensor_variant; the dtype argument, when
# present, restricts that variant to a single output dtype.
# NOTE(review): impl_tracked/m are defined earlier in the file; assumed to
# register the op under the given name — confirm against the full file.


@impl_tracked(m, "quantize_per_tensor")
@quantize_per_tensor_variant()
def quantize_per_tensor() -> torch.Tensor: ...


@impl_tracked(m, "quantize_per_tensor_asym8u")
@quantize_per_tensor_variant(torch.uint8)
def quantize_per_tensor_asym8u() -> torch.Tensor: ...


@impl_tracked(m, "quantize_per_tensor_asym8s")
@quantize_per_tensor_variant(torch.int8)
def quantize_per_tensor_asym8s() -> torch.Tensor: ...


@impl_tracked(m, "quantize_per_tensor_asym16u")
@quantize_per_tensor_variant(torch.uint16)
def quantize_per_tensor_asym16u() -> torch.Tensor: ...


@impl_tracked(m, "quantize_per_tensor_asym16s")
@quantize_per_tensor_variant(torch.int16)
def quantize_per_tensor_asym16s() -> torch.Tensor: ...


@impl_tracked(m, "quantize_per_tensor_asym32s")
@quantize_per_tensor_variant(torch.int32)
def quantize_per_tensor_asym32s() -> torch.Tensor: ...
154+ 
155+ 
156+ def  dequantize_per_tensor_common (
98157    input_tensor : torch .Tensor ,
99158    scale : float ,
100159    zero_point : int ,
@@ -133,14 +192,72 @@ def dequantize_per_tensor(
133192    if  input_tensor .dtype  !=  dtype :
134193        raise  ValueError ("Input dtype must match dtype" )
135194
136-     # Use the reference implementation from torch quantized_decomposed library 
137-     # Unlike quantize_per_tensor, dequantize_per_tensor doesn't have a behavior 
138-     # difference, since there's no rounding algorithm (just arithmetic). 
139195    return  torch .ops .quantized_decomposed .dequantize_per_tensor (
140196        input_tensor , scale , zero_point , quant_min , quant_max , dtype 
141197    )
142198
143199
200+ def  dequantize_per_tensor_variant (
201+     dtype : torch .dtype  |  None  =  None ,
202+ ) ->  Callable [[Callable [..., torch .Tensor ]], Callable [..., torch .Tensor ]]:
203+     """Create a dequantize_per_tensor variant with type checking.""" 
204+ 
205+     def  decorator (_ : Callable [..., torch .Tensor ]) ->  Callable [..., torch .Tensor ]:
206+         def  variant (
207+             input_tensor : torch .Tensor ,
208+             scale : float ,
209+             zero_point : int ,
210+             quant_min : int ,
211+             quant_max : int ,
212+             in_dtype : torch .dtype ,
213+         ) ->  torch .Tensor :
214+             if  dtype  and  in_dtype  !=  dtype :
215+                 raise  ValueError (f"dtype must be { dtype } { in_dtype }  )
216+ 
217+             return  dequantize_per_tensor_common (
218+                 input_tensor ,
219+                 scale ,
220+                 zero_point ,
221+                 quant_min ,
222+                 quant_max ,
223+                 in_dtype ,
224+             )
225+ 
226+         return  variant 
227+ 
228+     return  decorator 
229+ 
230+ 
# Registered dequantize_per_tensor reference implementations. Each stub body
# is replaced by dequantize_per_tensor_variant; the dtype argument, when
# present, restricts that variant to a single input dtype. Registration
# order is kept exactly as written.
# NOTE(review): impl_tracked/m are defined earlier in the file; assumed to
# register the op under the given name — confirm against the full file.


@impl_tracked(m, "dequantize_per_tensor")
@dequantize_per_tensor_variant()
def dequantize_per_tensor() -> torch.Tensor: ...


@impl_tracked(m, "dequantize_per_tensor_asym8u")
@dequantize_per_tensor_variant(torch.uint8)
def dequantize_per_tensor_asym8u() -> torch.Tensor: ...


@impl_tracked(m, "dequantize_per_tensor_asym32s")
@dequantize_per_tensor_variant(torch.int32)
def dequantize_per_tensor_asym32s() -> torch.Tensor: ...


@impl_tracked(m, "dequantize_per_tensor_asym16u")
@dequantize_per_tensor_variant(torch.uint16)
def dequantize_per_tensor_asym16u() -> torch.Tensor: ...


@impl_tracked(m, "dequantize_per_tensor_asym8s")
@dequantize_per_tensor_variant(torch.int8)
def dequantize_per_tensor_asym8s() -> torch.Tensor: ...


@impl_tracked(m, "dequantize_per_tensor_asym16s")
@dequantize_per_tensor_variant(torch.int16)
def dequantize_per_tensor_asym16s() -> torch.Tensor: ...
259+ 
260+ 
144261@impl_tracked (m , "quantized_add.per_tensor" ) 
145262def  quantized_add_per_tensor (
146263    X : torch .Tensor ,
0 commit comments