Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion tilelang/language/gemm_op.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""GEMM (General Matrix Multiplication) operators exposed on the TileLang language surface."""

from __future__ import annotations

from tilelang.tileop.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
Expand Down Expand Up @@ -184,4 +185,38 @@ def gemm_v2(


# Default to v2; allow forcing v1 via environment variable
gemm = gemm_v1 if _env.use_gemm_v1() else gemm_v2
# gemm = gemm_v1 if _env.use_gemm_v1() else gemm_v2


def gemm(
A: tir.Buffer | tir.Var,
B: tir.Buffer | tir.Var,
C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
mbar: tir.Buffer | None = None,
):
"""TileLang GEMM operator.

Args:
A (tir.Buffer | tir.Var): Input buffer A.
B (tir.Buffer | tir.Var): Input buffer B.
C (tir.Buffer | tir.Var): Output buffer C.
transpose_A (bool): Whether to transpose A. Defaults to False.
transpose_B (bool): Whether to transpose B. Defaults to False.
policy (GemmWarpPolicy): GEMM warp partition policy.
clear_accum (bool): Whether to clear the accumulator.
k_pack (int): Numbers of packed matrix cores, for ROCm only. Defaults to 1.
wg_wait (int): Int identifier of the warpgroup MMA batch to wait on.. Defaults to 0.
mbar (tir.Buffer | None, optional): Mbarrier in Blackwell. Defaults to None.

Returns:
tir.Call: A handle to the GEMM operation.
"""

impl = gemm_v1 if _env.use_gemm_v1() else gemm_v2
return impl(A, B, C, transpose_A, transpose_B, policy, clear_accum, k_pack, wg_wait, mbar)
48 changes: 44 additions & 4 deletions tilelang/language/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,28 @@ def Pipelined(
def serial(
start: tir.PrimExpr, stop: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None, *, annotations: dict[str, Any] | None = None
) -> frame.ForFrame:
"""The serial For statement.

Parameters
----------
start : PrimExpr
The minimum value of iteration.

stop : PrimExpr
The maximum value of iteration.

step : PrimExpr
The step size of the iteration.

annotations : Dict[str, Any]
The optional annotations of the For statement.

Returns
-------
res : frame.ForFrame
The ForFrame.
"""

step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
Expand Down Expand Up @@ -178,9 +200,27 @@ def unroll(
# "Serial" and "Unroll" are aliases of "T.serial" and "T.unroll". We use uppercase to emphasize that they are tile-level loops.


def Serial(*args, **kwargs):
return serial(*args, **kwargs)
def Serial(
start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None,
):
"""Alias of T.serial."""

return serial(start, stop, step, annotations=annotations)


def Unroll(
start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
explicit: bool = False,
unroll_factor: int | None = None,
annotations: dict[str, Any] | None = None,
):
"""Alias of T.unroll."""

def Unroll(*args, **kwargs):
return unroll(*args, **kwargs)
return unroll(start, stop, step, explicit=explicit, unroll_factor=unroll_factor, annotations=annotations)
Loading