File tree Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Original file line number Diff line number Diff line change @@ -37,18 +37,24 @@ def define_node(
3737 inputs : List [TosaArg ],
3838 output : TosaArg ,
3939 ) -> None :
40+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
41+ raise TypeError (
42+ f"All IO needs to have the same data type, got: "
43+ f"{ inputs [0 ].dtype = } , { inputs [1 ].dtype = } and { output .dtype = } "
44+ )
4045
41- assert inputs [0 ].dtype == inputs [1 ].dtype , "Both inputs must be of same type"
42- assert inputs [0 ].dtype in [
43- ts .DType .INT8 ,
44- ts .DType .FP32 ,
45- ], "Only int8 and float32 supported"
46- # aten.bmm maps directly to MATMUL
4746 # NOTE: For now, only INT8 & FP32 is supported
47+ supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
48+ for input in inputs :
49+ if input .dtype not in supported_dtypes :
50+ raise TypeError (
51+ f'IO data type needs to be { supported_dtypes } , got "{ input .dtype } "'
52+ )
53+
54+ # aten.bmm maps directly to MATMUL
4855
4956 # For INT8, we need to get the zero points and add an intermediate tensor
5057 # for a later rescale.
51-
5258 if inputs [0 ].dtype == ts .DType .INT8 :
5359 input_qparams = get_input_qparams (node )
5460 input0_zp = input_qparams [0 ].zp
You can’t perform that action at this time.
0 commit comments