diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 676c953ee..70dc1690b 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -2,6 +2,7 @@ from contextlib import contextmanager, AbstractContextManager from dataclasses import dataclass import inspect +import sys from tilelang.language.kernel import KernelLaunchFrame from tvm_ffi.container import Map @@ -815,7 +816,10 @@ def get_type_hints(func): continue except Exception: pass - value = ForwardRef(value, is_argument=True, is_class=False) + if sys.version_info >= (3, 10): + value = ForwardRef(value, module=func.__module__) + else: + value = ForwardRef(value, is_argument=True) hints[name] = _eval_type(value, globalns=globalns, localns=localns) else: hints[name] = value