File tree Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -46,13 +46,20 @@ def define_node(
4646 input_zp = cast (int , node .args [3 ])
4747 output_zp = cast (int , node .args [4 ])
4848
49- if input_dtype != map_dtype (torch .int8 , self .tosa_spec ) and input_zp != 0 :
49+ if (
50+ input_dtype
51+ not in [
52+ map_dtype (torch .int8 , self .tosa_spec ),
53+ map_dtype (torch .int16 , self .tosa_spec ),
54+ ]
55+ and input_zp != 0
56+ ):
5057 raise ValueError (
51- f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ input_dtype = } , { input_zp = } "
58+ f"If input dtype is not int8 or int16 , input_zp must be 0. Got input_dtype{ input_dtype = } , { input_zp = } "
5259 )
53- if output_dtype != torch .int8 and output_zp != 0 :
60+ if output_dtype not in [ torch .int8 , torch . int16 ] and output_zp != 0 :
5461 raise ValueError (
55- f"If output dtype is not int8, output_zp must be 0. Got { ts .DTypeNames [output_dtype ]} , { output_zp = } "
62+ f"If output dtype is not int8 or int16 , output_zp must be 0. Got { ts .DTypeNames [output_dtype ]} , { output_zp = } "
5663 )
5764
5865 build_rescale (
You can’t perform that action at this time.
0 commit comments