diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py
index 73f7ed949..16a1d3d10 100644
--- a/tilelang/language/kernel.py
+++ b/tilelang/language/kernel.py
@@ -226,58 +226,22 @@ def num_threads(self) -> int:
 def Kernel(
-    *blocks: list[tir.PrimExpr],
+    *blocks: tir.PrimExpr,
     threads: int | list[int] | tuple | None = None,
     is_cpu: bool = False,
     prelude: str | None = None,
 ):
-    """Tools to quickly construct a GPU kernel launch frame.
-
-    Parameters
-    ----------
-    blocks : List[int]
-        A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z)
-    threads : int
-        A integer representing blockDim.x
-        Or a list of integers representing blockDim.(x|y|z)
-        if the value is -1, we skip the threadIdx.x binding.
-    is_cpu : bool
-        Whether the kernel is running on CPU.
-        Thus we will not bind threadIdx.x, threadIdx.y, threadIdx.z.
-        and blockIdx.x, blockIdx.y, blockIdx.z.
-    prelude : str
-        The import c code of the kernel,
-        will be injected before the generated kernel code.
-
-    Returns
-    -------
-    res : Tuple[frame.LaunchThreadFrame]
-        The result LaunchThreadFrame.
-
-    Examples
-    --------
-    Create a 1-D CUDA kernel launch and unpack the single block index:
-
-    .. code-block:: python
-
-        with T.Kernel(T.ceildiv(N, 128), threads=128) as bx:
-            # bx is the blockIdx.x binding (also iterable as (bx,))
-            ...
-
-    Launch a 2-D grid while requesting two thread dimensions:
-
-    .. code-block:: python
-
-        with T.Kernel(grid_x, grid_y, threads=(64, 2)) as (bx, by):
-            tx, ty = T.get_thread_bindings()
-            ...
-
-    Emit a CPU kernel where thread bindings are skipped:
-
-    .. code-block:: python
-
-        with T.Kernel(loop_extent, is_cpu=True) as (i,):
-            ...
+    """
+    Construct a kernel launch frame for TileLang and return a launch-frame descriptor usable as a context manager.
+
+    Parameters:
+        blocks: One to three tir.PrimExpr values specifying grid dimensions for blockIdx.(x|y|z).
+        threads: An int, list, or tuple specifying thread block dimensions.
+            When not provided and not a CPU kernel, defaults to 128 for the first dimension and 1 for missing dimensions; values are normalized to a length-three list [x, y, z].
+        is_cpu: If True, mark the frame as a CPU kernel so thread and block thread-bindings are not created.
+        prelude: Optional C source to be injected before the generated kernel code via pragma_import_c.
+
+    Returns:
+        A KernelLaunchFrame descriptor (FFI handle) that can be used as a context manager; entering the context yields the block binding Var for a single-dimension grid or a tuple/list of Vars for multiple dimensions.
+    """
     attrs: dict = {}
@@ -347,4 +311,4 @@ def get_block_extent(dim: int = 0) -> int:
 def get_block_extents() -> list[int]:
     """Returns all three block extents."""
     assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
-    return KernelLaunchFrame.Current().get_block_extents()
+    return KernelLaunchFrame.Current().get_block_extents()
\ No newline at end of file
diff --git a/tilelang/language/v2/annot.py b/tilelang/language/v2/annot.py
index bac92142c..4cfcd0ec7 100644
--- a/tilelang/language/v2/annot.py
+++ b/tilelang/language/v2/annot.py
@@ -582,7 +582,33 @@ def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]
     @property
     def strides(self) -> tuple[tir.PrimExpr]: ...
-    def scope(self) -> Scope: ...
+    def scope(self) -> Scope:
+        """Get the memory scope identifier used for buffers when emitting TIR.
+
+        Returns:
+            scope (Scope): Memory scope name (e.g. "global", "local", "shared.dyn", "local.fragment") used for the created buffer.
+        """
+        ...
+
+    def __getitem__(self, idx) -> Buffer:
+        """Access a sub-buffer or view of this Buffer using indexing.
+
+        Parameters:
+            idx (int | slice | tuple): Index, slice, or tuple of indices that selects the sub-buffer, dimension(s), or region.
+
+        Returns:
+            Buffer: A Buffer representing the selected sub-buffer or view.
+        """
+        ...
+
+    def __setitem__(self, idx, val):
+        """Assign a value to the buffer element(s) specified by `idx`.
+
+        Parameters:
+            idx: An index or slice specifying which element(s) of the buffer to assign.
+            val: The value to assign to the selected element(s); should be a buffer-compatible object.
+        """
+        ...
 
 class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
     def __new__(
@@ -596,7 +622,25 @@ def __new__(
         shape,
         dtype="float32",
         data=None,
         strides=None,
         elem_offset=None,
         scope=None,
         align=0,
         offset_factor=0,
         buffer_type="",
         axis_separators=None,
-    ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
+    ) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
+        """Construct a Tensor annotation with the given shape, dtype, and buffer properties.
+
+        Parameters:
+            shape (tuple): Tensor shape expressed as ints or PrimExprs; a single int is treated as a 1-D shape when provided.
+            dtype (str | dt.dtype): Element data type for the tensor (defaults to "float32").
+            data: Optional backing object or constant used to populate the buffer (framework-specific).
+            strides: Optional tuple of strides; if omitted, contiguous row-major strides are computed.
+            elem_offset: Optional element offset within the buffer.
+            scope (str): Memory scope string for the buffer (e.g., "global", "local"); uses the annot's default when None.
+            align (int): Alignment in bytes for the buffer.
+            offset_factor (int): Offset factor for buffer addressing.
+            buffer_type (str): Optional buffer type metadata string.
+            axis_separators: Optional layout separators for visual/serialization purposes.
+
+        Returns:
+            Tensor: A Tensor annotation describing a buffer with the specified shape, dtype, and layout properties.
+        """
+        ...
 
 class StridedTensor(Generic[_Shape, _Stride, _DType], Buffer[_Shape, _DType]):
     def __new__(
@@ -711,4 +755,4 @@ def get_compile_time_unknown_args(self):
         for name, annot in self.annots.items():
             if not isinstance(annot, TIRAnnot):
                 res.append(name)
-        return res
+        return res
\ No newline at end of file