diff --git a/tilelang/language/v2/annot.py b/tilelang/language/v2/annot.py index 0afded38f..14395bd61 100644 --- a/tilelang/language/v2/annot.py +++ b/tilelang/language/v2/annot.py @@ -102,9 +102,6 @@ def from_value(cls, value: Any, prefer_name: str = None) -> Value: return Value(kind='static', name=prefer_name, dtype=dt.int32, value=value) elif isinstance(value, float): return Value(kind='static', name=prefer_name, dtype=dt.float32, value=value) - elif isinstance(value, tir.Var): - # handle A: T.Tensor[[M, N, K], ...] - return Value(kind='dynamic', name=value.name, dtype=value.dtype, value=value) elif isinstance(value, dt.dtype): # handle A: T.float32 return Value(kind='dynamic', name=prefer_name, dtype=value, value=None) @@ -113,6 +110,11 @@ def from_value(cls, value: Any, prefer_name: str = None) -> Value: return value elif isinstance(value, TypeVar): return Value(kind='static', name=value.__name__, value=None) + elif isinstance(value, (tir.Var, PrimExpr)): + # handle A: T.Tensor[[M, N, K], ...] + # or primexpr annotation like A: T.Tensor[[M, N * 4 +1]] + name = value.name if isinstance(value, tir.Var) else prefer_name + return Value(kind='dynamic', name=name, dtype=value.dtype, value=value) elif value is Any or value is None or value is dt.dtype or isinstance( value, (type, _GenericAlias)): # A # no annotation @@ -122,7 +124,7 @@ def from_value(cls, value: Any, prefer_name: str = None) -> Value: # A: tuple[...] return Value(kind='static', name=prefer_name, value=None) else: - raise TypeError(f"Unsupported Value annotation: {value!r}") + raise TypeError(f"Unsupported Value annotation: {value!r}, type: {type(value)}") def with_name(self, name: str) -> Value: return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value)