diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index 20018610fce..6806e90ac34 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -144,9 +144,10 @@ def check_common_constraints( return True def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): - # Check inputs are valid dtypes + # Check inputs are valid and have the same dtypes # Gather all args which are nodes args_to_check = [] + reference_dtype = None for arg in node.args: if isinstance(arg, list) or isinstance(arg, tuple): for item in arg: @@ -174,11 +175,19 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): if arg_val.dtype not in valid_dtypes: return False + # Check for mixed dtypes + if reference_dtype is None: + reference_dtype = arg_val.dtype + elif arg_val.dtype != reference_dtype: + return False + return True def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): - # Check outputs are valid dtype + # Check outputs are valid and have the same dtypes node_val = node.meta.get("val", None) + reference_dtype = None + if node_val is None: return True @@ -192,6 +201,12 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): if val.dtype not in valid_dtypes: return False + # Check for mixed dtypes + if reference_dtype is None: + reference_dtype = val.dtype + elif val.dtype != reference_dtype: + return False + return True def _check_node_has_valid_dtype(self, node):