Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 13 additions & 49 deletions tilelang/language/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,58 +226,22 @@ def num_threads(self) -> int:


def Kernel(
*blocks: list[tir.PrimExpr],
*blocks: tir.PrimExpr,
threads: int | list[int] | tuple | None = None,
is_cpu: bool = False,
prelude: str | None = None,
):
"""Tools to quickly construct a GPU kernel launch frame.

Parameters
----------
blocks : List[int]
A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z)
threads : int
An integer representing blockDim.x
Or a list of integers representing blockDim.(x|y|z)
if the value is -1, we skip the threadIdx.x binding.
is_cpu : bool
Whether the kernel is running on CPU.
Thus we will not bind threadIdx.x, threadIdx.y, threadIdx.z,
or blockIdx.x, blockIdx.y, blockIdx.z.
prelude : str
The C import code of the kernel,
which will be injected before the generated kernel code.

Returns
-------
res : Tuple[frame.LaunchThreadFrame]
The result LaunchThreadFrame.

Examples
--------
Create a 1-D CUDA kernel launch and unpack the single block index:

.. code-block:: python

with T.Kernel(T.ceildiv(N, 128), threads=128) as bx:
# bx is the blockIdx.x binding (also iterable as (bx,))
...

Launch a 2-D grid while requesting two thread dimensions:

.. code-block:: python

with T.Kernel(grid_x, grid_y, threads=(64, 2)) as (bx, by):
tx, ty = T.get_thread_bindings()
...

Emit a CPU kernel where thread bindings are skipped:

.. code-block:: python

with T.Kernel(loop_extent, is_cpu=True) as (i,):
...
"""
Construct a kernel launch frame for TileLang and return a launch-frame descriptor usable as a context manager.

Parameters:
blocks: One to three tir.PrimExpr values specifying grid dimensions for blockIdx.(x|y|z).
threads: An int, list, or tuple specifying thread block dimensions. When not provided and not a CPU kernel, defaults to 128 for the first dimension and 1 for missing dimensions; values are normalized to a length-three list [x, y, z].
is_cpu: If True, mark the frame as a CPU kernel so thread and block thread-bindings are not created.
prelude: Optional C source to be injected before the generated kernel code via pragma_import_c.

Returns:
A KernelLaunchFrame descriptor (FFI handle) that can be used as a context manager; entering the context yields the block binding Var for a single-dimension grid or a tuple/list of Vars for multiple dimensions.
"""
attrs: dict = {}

Expand Down Expand Up @@ -347,4 +311,4 @@ def get_block_extent(dim: int = 0) -> int:
def get_block_extents() -> list[int]:
"""Returns all three block extents."""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_block_extents()
return KernelLaunchFrame.Current().get_block_extents()
50 changes: 47 additions & 3 deletions tilelang/language/v2/annot.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,33 @@ def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]
@property
def strides(self) -> tuple[tir.PrimExpr]: ...

def scope(self) -> Scope: ...
def scope(self) -> Scope: """
Get the memory scope identifier used for buffers when emitting TIR.

Returns:
scope (Scope): Memory scope name (e.g. "global", "local", "shared.dyn", "local.fragment") used for the created buffer.
"""
...

def __getitem__(self, idx) -> Buffer: """
Access a sub-buffer or view of this Buffer using indexing.

Parameters:
idx (int | slice | tuple): Index, slice, or tuple of indices that selects the sub-buffer, dimension(s), or region.

Returns:
Buffer: A Buffer representing the selected sub-buffer or view.
"""
...

def __setitem__(self, idx, val): """
Assign a value to the buffer element(s) specified by `idx`.

Parameters:
idx: An index or slice specifying which element(s) of the buffer to assign.
val: The value to assign to the selected element(s); should be a buffer-compatible object.
"""
...

class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
def __new__(
Expand All @@ -596,7 +622,25 @@ def __new__(
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: """
Construct a Tensor annotation with the given shape, dtype, and buffer properties.

Parameters:
shape (tuple): Tensor shape expressed as ints or PrimExprs; a single int is treated as a 1-D shape when provided.
dtype (str | dt.dtype): Element data type for the tensor (defaults to "float32").
data: Optional backing object or constant used to populate the buffer (framework-specific).
strides: Optional tuple of strides; if omitted, contiguous row-major strides are computed.
elem_offset: Optional element offset within the buffer.
scope (str): Memory scope string for the buffer (e.g., "global", "local"); uses the annot's default when None.
align (int): Alignment in bytes for the buffer.
offset_factor (int): Offset factor for buffer addressing.
buffer_type (str): Optional buffer type metadata string.
axis_separators: Optional layout separators for visual/serialization purposes.

Returns:
Tensor: A Tensor annotation describing a buffer with the specified shape, dtype, and layout properties.
"""
...

class StridedTensor(Generic[_Shape, _Stride, _DType], Buffer[_Shape, _DType]):
def __new__(
Expand Down Expand Up @@ -711,4 +755,4 @@ def get_compile_time_unknown_args(self):
for name, annot in self.annots.items():
if not isinstance(annot, TIRAnnot):
res.append(name)
return res
return res
Loading