diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 17995bfa7850..f7d40dc68147 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass infrastructure across IR variants.""" -import types import inspect import functools @@ -340,7 +339,7 @@ def create_module_pass(pass_arg): info = PassInfo(opt_level, fname, required) if inspect.isclass(pass_arg): return _wrap_class_module_pass(pass_arg, info) - if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + if not callable(pass_arg): raise TypeError("pass_func must be a callable for Module pass") return _ffi_transform_api.MakeModulePass(pass_arg, info) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 1f5b91da4432..4c609620cbb7 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1049,7 +1049,7 @@ def create_function_pass(pass_arg): info = tvm.transform.PassInfo(opt_level, fname, required) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) - if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + if not callable(pass_arg): raise TypeError("pass_func must be a callable for Module pass") return _ffi_api.MakeFunctionPass(pass_arg, info) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index ec103ac18811..bd47e416305f 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -20,7 +20,6 @@ import operator import logging import sys -import types import numbers from enum import Enum @@ -142,7 +141,7 @@ def __init__(self, args, usage, symbols, closure_vars, func_name=None): self.symbols = {} # Symbol table for k, v in symbols.items(): - if isinstance(v, types.FunctionType): + if callable(v): self.add_symbol(k, Symbol.Callable, v) self.closure_vars = closure_vars diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 9450ade34e67..9fa0e3bc181f 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -16,7 +16,6 @@ # under the License. """TIR specific function pass support.""" import inspect -import types import functools from typing import Callable, List, Optional, Union @@ -151,7 +150,7 @@ def create_function_pass(pass_arg): info = PassInfo(opt_level, fname, required) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) - if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + if not callable(pass_arg): raise TypeError("pass_func must be a callable for Module pass") return _ffi_api.CreatePrimFuncPass(pass_arg, info) # type: ignore