Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tilelang/language/allocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
with the appropriate memory scope.
"""
from __future__ import annotations
from typing import TypeVarTuple, TypeVar, overload, Literal, Unpack, Callable
from typing import TypeVar, overload, Literal, Callable
# Python 3.9 compatibility for advanced typing features (PEP 646)
try:
from typing import TypeVarTuple, Unpack # type: ignore[attr-defined]
except Exception:
from typing_extensions import TypeVarTuple, Unpack # type: ignore
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
Expand Down
39 changes: 26 additions & 13 deletions tilelang/language/v2/annot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,31 @@
from tvm import tir
from tvm.ir.expr import PrimExpr
from tvm.script.ir_builder.tir import buffer
from typing import Any, Callable, Literal, TypeVar, ParamSpec, Generic, TypeVarTuple, Unpack, TYPE_CHECKING, _GenericAlias, Self
from typing import Any, Callable, Literal, TypeVar, Generic, TYPE_CHECKING
# Python 3.9 compatibility for advanced typing features
try:
from typing import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore[attr-defined]
except Exception: # Python < 3.10 for ParamSpec, < 3.11 for Unpack/TypeVarTuple/Self
from typing_extensions import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore

# Compatibility for generic alias detection across Python versions
try:
from typing import _GenericAlias as _TypingGenericAlias # type: ignore[attr-defined]
except Exception:
_TypingGenericAlias = None # type: ignore
try:
# Builtin generic alias type for e.g. tuple[int]
from types import GenericAlias as _TypesGenericAlias # type: ignore[attr-defined]
except Exception:
_TypesGenericAlias = None # type: ignore

_GenericAliasTypes = tuple(t for t in (_TypingGenericAlias, _TypesGenericAlias) if t is not None)
if not _GenericAliasTypes:

class _DummyGenericAlias: # type: ignore
pass

_GenericAliasTypes = (_DummyGenericAlias,) # type: ignore
from collections.abc import Sequence
from .dtypes import AnyDType
from . import dtypes as dt
Expand Down Expand Up @@ -116,7 +140,7 @@ def from_value(cls, value: Any, prefer_name: str = None) -> Value:
name = value.name if isinstance(value, tir.Var) else prefer_name
return Value(kind='dynamic', name=name, dtype=value.dtype, value=value)
elif value is Any or value is None or value is dt.dtype or isinstance(
value, (type, _GenericAlias)):
value, (type,) + _GenericAliasTypes):
# A # no annotation
# A: Any
# A: _T
Expand Down Expand Up @@ -358,17 +382,6 @@ def promote(self):
buf = buffer(shape, self.dtype, strides=strides, scope=self.scope)
return TIRAnnot(data=buf)

# def __repr__(self):
# items = []
# if self.shape is not None:
# items.append(f'shape=[{', '.join(map(repr, self.shape))}]')
# if self.strides is not None:
# items.append(f'strides=[{', '.join(map(repr, self.strides))}]')
# if self.dtype is not None:
# items.append(f'dtype={self.dtype}')
# items.append(f'scope={repr(self.scope)}')
# return 'Buffer(' + ', '.join(items) + ')'


class TensorAnnot(BufferAnnot):

Expand Down
Loading