diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py
index e2bda2b94..0b8f3ccf1 100644
--- a/tilelang/language/gemm_op.py
+++ b/tilelang/language/gemm_op.py
@@ -1,6 +1,7 @@
 """GEMM (General Matrix Multiplication) operators exposed on the TileLang language surface."""
 
 from __future__ import annotations
+
 from tilelang.tileop.base import GemmWarpPolicy
 import tilelang.language as T
 from tvm import tir
@@ -184,4 +185,38 @@ def gemm_v2(
 
 
 # Default to v2; allow forcing v1 via environment variable
-gemm = gemm_v1 if _env.use_gemm_v1() else gemm_v2
+# gemm = gemm_v1 if _env.use_gemm_v1() else gemm_v2
+
+
+def gemm(
+    A: tir.Buffer | tir.Var,
+    B: tir.Buffer | tir.Var,
+    C: tir.Buffer | tir.Var,
+    transpose_A: bool = False,
+    transpose_B: bool = False,
+    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
+    clear_accum: bool = False,
+    k_pack: int = 1,
+    wg_wait: int = 0,
+    mbar: tir.Buffer | None = None,
+):
+    """TileLang GEMM operator.
+
+    Args:
+        A (tir.Buffer | tir.Var): Input buffer A.
+        B (tir.Buffer | tir.Var): Input buffer B.
+        C (tir.Buffer | tir.Var): Output buffer C.
+        transpose_A (bool): Whether to transpose A. Defaults to False.
+        transpose_B (bool): Whether to transpose B. Defaults to False.
+        policy (GemmWarpPolicy): GEMM warp partition policy.
+        clear_accum (bool): Whether to clear the accumulator.
+        k_pack (int): Number of packed matrix cores, for ROCm only. Defaults to 1.
+        wg_wait (int): Int identifier of the warpgroup MMA batch to wait on. Defaults to 0.
+        mbar (tir.Buffer | None, optional): Mbarrier in Blackwell. Defaults to None.
+
+    Returns:
+        tir.Call: A handle to the GEMM operation.
+    """
+
+    impl = gemm_v1 if _env.use_gemm_v1() else gemm_v2
+    return impl(A, B, C, transpose_A, transpose_B, policy, clear_accum, k_pack, wg_wait, mbar)
diff --git a/tilelang/language/loop.py b/tilelang/language/loop.py
index 4fbd4e9f8..45b768095 100644
--- a/tilelang/language/loop.py
+++ b/tilelang/language/loop.py
@@ -98,6 +98,28 @@ def Pipelined(
 def serial(
     start: tir.PrimExpr, stop: tir.PrimExpr | None = None, step: tir.PrimExpr | None = None, *, annotations: dict[str, Any] | None = None
 ) -> frame.ForFrame:
+    """The serial For statement.
+
+    Parameters
+    ----------
+    start : PrimExpr
+        The minimum value of iteration.
+
+    stop : PrimExpr
+        The maximum value of iteration.
+
+    step : PrimExpr
+        The step size of the iteration.
+
+    annotations : Dict[str, Any]
+        The optional annotations of the For statement.
+
+    Returns
+    -------
+    res : frame.ForFrame
+        The ForFrame.
+    """
+
     step_is_one = False
     step_is_one |= isinstance(step, int) and step == 1
     step_is_one |= isinstance(step, IntImm) and step.value == 1
@@ -178,9 +200,27 @@ def unroll(
 
 
 # "Serial" and "Unroll" are aliases of "T.serial" and "T.unroll". We use uppercase to emphasize that they are tile-level loops.
-def Serial(*args, **kwargs):
-    return serial(*args, **kwargs)
+def Serial(
+    start: tir.PrimExpr,
+    stop: tir.PrimExpr | None = None,
+    step: tir.PrimExpr | None = None,
+    *,
+    annotations: dict[str, Any] | None = None,
+):
+    """Alias of T.serial."""
+
+    return serial(start, stop, step, annotations=annotations)
+
 
+def Unroll(
+    start: tir.PrimExpr,
+    stop: tir.PrimExpr | None = None,
+    step: tir.PrimExpr | None = None,
+    *,
+    explicit: bool = False,
+    unroll_factor: int | None = None,
+    annotations: dict[str, Any] | None = None,
+):
+    """Alias of T.unroll."""
 
-def Unroll(*args, **kwargs):
-    return unroll(*args, **kwargs)
+    return unroll(start, stop, step, explicit=explicit, unroll_factor=unroll_factor, annotations=annotations)