Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tilelang/language/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def num_threads(self) -> int:


def Kernel(
*blocks: list[tir.PrimExpr],
*blocks: tir.PrimExpr,
threads: int | list[int] | tuple | None = None,
is_cpu: bool = False,
prelude: str | None = None,
Expand All @@ -235,7 +235,7 @@ def Kernel(

Parameters
----------
blocks : List[int]
blocks : int
A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z)
threads : int
A integer representing blockDim.x
Expand Down
3 changes: 3 additions & 0 deletions tilelang/language/tir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm

_T = TypeVar("_T")

def Cast(dtype, value: _T, span: Span | None = None) -> _T: ...
def abs(x: _T, span: Span | None = None) -> _T: ...
def acos(x: _T) -> _T: ...
def acosh(x: _T) -> _T: ...
Expand Down Expand Up @@ -44,7 +45,9 @@ def log1p(x: _T) -> _T: ...
def log2(x: _T) -> _T: ...
def log10(x: _T) -> _T: ...
def lookup_param(param_name: str, span: Span | None = None) -> PrimExpr: ...
def max(x: _T, y: _T, span: Span | None = None) -> _T: ...
def max_value(dtype: str, span: Span | None = None) -> PrimExpr: ...
def min(x: _T, y: _T, span: Span | None = None) -> _T: ...
def min_value(dtype: str, span: Span | None = None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None = None) -> _T: ...
def nextafter(x1: _T, x2: _T) -> _T: ...
Expand Down
4 changes: 4 additions & 0 deletions tilelang/language/v2/annot.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,10 @@ def strides(self) -> tuple[tir.PrimExpr]: ...

def scope(self) -> Scope: ...

def __getitem__(self, idx) -> Buffer: ...

def __setitem__(self, idx, val): ...

class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
def __new__(
shape: tuple[Unpack[_Shapes]],
Expand Down
Loading