diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 8036e6acf..73377822b 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -14,7 +14,12 @@ with the appropriate memory scope. """ from __future__ import annotations -from typing import TypeVarTuple, TypeVar, overload, Literal, Unpack, Callable +from typing import TypeVar, overload, Literal, Callable +# Python 3.9 compatibility for advanced typing features (PEP 646) +try: + from typing import TypeVarTuple, Unpack # type: ignore[attr-defined] +except Exception: + from typing_extensions import TypeVarTuple, Unpack # type: ignore from tilelang import tvm as tvm from tvm.script import tir as T from tvm.tir import PrimExpr diff --git a/tilelang/language/v2/annot.py b/tilelang/language/v2/annot.py index 14395bd61..b61d9d11c 100644 --- a/tilelang/language/v2/annot.py +++ b/tilelang/language/v2/annot.py @@ -4,7 +4,31 @@ from tvm import tir from tvm.ir.expr import PrimExpr from tvm.script.ir_builder.tir import buffer -from typing import Any, Callable, Literal, TypeVar, ParamSpec, Generic, TypeVarTuple, Unpack, TYPE_CHECKING, _GenericAlias, Self +from typing import Any, Callable, Literal, TypeVar, Generic, TYPE_CHECKING +# Python 3.9 compatibility for advanced typing features +try: + from typing import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore[attr-defined] +except Exception: # Python < 3.10 for ParamSpec, < 3.11 for Unpack/TypeVarTuple/Self + from typing_extensions import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore + +# Compatibility for generic alias detection across Python versions +try: + from typing import _GenericAlias as _TypingGenericAlias # type: ignore[attr-defined] +except Exception: + _TypingGenericAlias = None # type: ignore +try: + # Builtin generic alias type for e.g. tuple[int] + from types import GenericAlias as _TypesGenericAlias # type: ignore[attr-defined] +except Exception: + _TypesGenericAlias = None # type: ignore + +_GenericAliasTypes = tuple(t for t in (_TypingGenericAlias, _TypesGenericAlias) if t is not None) +if not _GenericAliasTypes: + + class _DummyGenericAlias: # type: ignore + pass + + _GenericAliasTypes = (_DummyGenericAlias,) # type: ignore from collections.abc import Sequence from .dtypes import AnyDType from . import dtypes as dt @@ -116,7 +140,7 @@ def from_value(cls, value: Any, prefer_name: str = None) -> Value: name = value.name if isinstance(value, tir.Var) else prefer_name return Value(kind='dynamic', name=name, dtype=value.dtype, value=value) elif value is Any or value is None or value is dt.dtype or isinstance( - value, (type, _GenericAlias)): + value, (type,) + _GenericAliasTypes): # A # no annotation # A: Any # A: _T @@ -358,17 +382,6 @@ def promote(self): buf = buffer(shape, self.dtype, strides=strides, scope=self.scope) return TIRAnnot(data=buf) - # def __repr__(self): - # items = [] - # if self.shape is not None: - # items.append(f'shape=[{', '.join(map(repr, self.shape))}]') - # if self.strides is not None: - # items.append(f'strides=[{', '.join(map(repr, self.strides))}]') - # if self.dtype is not None: - # items.append(f'dtype={self.dtype}') - # items.append(f'scope={repr(self.scope)}') - # return 'Buffer(' + ', '.join(items) + ')' - class TensorAnnot(BufferAnnot):