diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 875ec3b83ea8..be17001fd034 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -511,6 +511,18 @@ def from_exported_program( ): output = None with self.block_builder.dataflow(): + + # Find all the missing function types + missing_func_types = list( + { + node.target.__name__ + for node in nodes + if node.op == "call_function" + and node.target.__name__ not in self.convert_map + } + ) + assert not missing_func_types, f"Unsupported function types {missing_func_types}" + # Translate the model. for node in nodes: if node.op == "placeholder": @@ -537,9 +549,6 @@ def from_exported_program( self.env[node] = getattr(exported_program.graph_module, node.target) elif node.op == "call_function": func_name = node.target.__name__ - assert ( - func_name in self.convert_map - ), f"Unsupported function type {func_name}" self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a5b50a7d1dce..f6dd235d5a23 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -884,6 +884,18 @@ def from_fx( with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs): output = None with self.block_builder.dataflow(): + + # Find all the missing function types + missing_func_types = list( + { + node.target.__name__ + for node in graph.nodes + if node.op == "call_function" + and node.target.__name__ not in self.convert_map + } + ) + assert not missing_func_types, f"Unsupported function types {missing_func_types}" + # Translate model parameters. for _, param in model.named_parameters(): shape = param.data.shape @@ -929,9 +941,6 @@ def from_fx( self.env[node] = self.convert_map[type(module)](node) elif node.op == "call_function": func_name = node.target.__name__ - assert ( - func_name in self.convert_map - ), f"Unsupported function type {func_name}" if func_name in custom_ops: self.env[node] = self.convert_map[func_name](node, self) else: