diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 73f7ed949..8679971e4 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -226,7 +226,7 @@ def num_threads(self) -> int: def Kernel( - *blocks: list[tir.PrimExpr], + *blocks: tir.PrimExpr, threads: int | list[int] | tuple | None = None, is_cpu: bool = False, prelude: str | None = None, @@ -235,7 +235,7 @@ def Kernel( Parameters ---------- - blocks : List[int] + blocks : int A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z) threads : int A integer representing blockDim.x diff --git a/tilelang/language/tir/ir.pyi b/tilelang/language/tir/ir.pyi index 7723f1378..76199cc3b 100644 --- a/tilelang/language/tir/ir.pyi +++ b/tilelang/language/tir/ir.pyi @@ -3,6 +3,7 @@ from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm _T = TypeVar("_T") +def Cast(dtype, value: _T, span: Span | None = None) -> _T: ... def abs(x: _T, span: Span | None = None) -> _T: ... def acos(x: _T) -> _T: ... def acosh(x: _T) -> _T: ... @@ -44,7 +45,9 @@ def log1p(x: _T) -> _T: ... def log2(x: _T) -> _T: ... def log10(x: _T) -> _T: ... def lookup_param(param_name: str, span: Span | None = None) -> PrimExpr: ... +def max(x: _T, y: _T, span: Span | None = None) -> _T: ... def max_value(dtype: str, span: Span | None = None) -> PrimExpr: ... +def min(x: _T, y: _T, span: Span | None = None) -> _T: ... def min_value(dtype: str, span: Span | None = None) -> PrimExpr: ... def nearbyint(x: _T, span: Span | None = None) -> _T: ... def nextafter(x1: _T, x2: _T) -> _T: ... diff --git a/tilelang/language/v2/annot.py b/tilelang/language/v2/annot.py index bac92142c..0595818b5 100644 --- a/tilelang/language/v2/annot.py +++ b/tilelang/language/v2/annot.py @@ -584,6 +584,10 @@ def strides(self) -> tuple[tir.PrimExpr]: ... def scope(self) -> Scope: ... + def __getitem__(self, idx) -> Buffer: ... + + def __setitem__(self, idx, val): ... + class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]): def __new__( shape: tuple[Unpack[_Shapes]],