21 changes: 21 additions & 0 deletions testing/python/issue/test_tilelang_issue_1690.py
@@ -0,0 +1,21 @@
+import tilelang
+import tilelang.testing
+import tilelang.language as T
+
+
+def test_issue_1690():
+    @tilelang.jit()
+    def test(A):
+        N = T.const("N")
+        A: T.Tensor[[N], T.float32]
+        with T.Kernel():
+            tmp = T.alloc_fragment((N,), T.float32)
+            tmp_max = T.alloc_fragment(1, T.float32)
+            T.copy(A, tmp)
+            T.reduce_max(tmp, tmp_max, dim=0)
+
+    test.compile(N=16)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
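Reviewer note: this regression test pins down issue 1690, where a kernel whose only dimension comes from `T.const` failed to compile. As a sketch of the intended runtime path (hypothetical call, following the convention in test_tilelang_language_eager_jit.py where the JIT function is invoked directly with a tensor):

    import torch

    A = torch.randn(16, device="cuda")  # N binds to 16, matching compile(N=16)
    test(A)  # phase 1 traced the signature; phase 2 substitutes N = 16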
7 changes: 0 additions & 7 deletions testing/python/language/test_tilelang_language_eager_jit.py
@@ -166,11 +166,6 @@ def test_jit2_return():
    def copy_impl(A):
        M, N = A.shape
        B = T.empty(M, N, dtype=A.dtype)
-        M, N = A.shape
-        M_, N_ = B.shape
-        assert M == M_, f"M mismatch {M} {M_}"
-        assert N == N_, f"N mismatch {N} {N_}"
-        # assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
        with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
            T.copy(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128], B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
        return B
@@ -212,8 +207,6 @@ def copy6(A):
        A: T.StridedTensor[[N, M], [N_, M_], T.float32]
        return copy_impl(A)

-    tilelang.par_compile([copy.get_tir(T.Tensor((128, 128))) for copy in [copy1, copy2, copy3, copy4]])
-
    for copy in [copy1, copy2, copy3, copy4]:
        A = torch.randn(128, 128, device="cuda")
        B = copy(A)
3 changes: 2 additions & 1 deletion tilelang/jit/__init__.py
@@ -313,7 +313,8 @@ def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy"
        # auto: infer by checking if function returns PrimFunc directly
        if not isinstance(self.func, JITFunc):
            return "lazy"
-        return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
+        is_lazy_style = self.func._is_lazy_style(*args, **kwargs)
+        return "lazy" if is_lazy_style else "eager"

    def initialize_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
        if self.mode == "auto":
7 changes: 4 additions & 3 deletions tilelang/language/allocate.py
@@ -31,6 +31,7 @@
from . import dtypes as _dtypes
from .dtypes import dtype as tl_dtype
from .eager.builder import OutTensor
+from .proxy import Tensor

_Shapes = TypeVarTuple("_Shapes")
_DType = TypeVar("_DType")
@@ -264,15 +265,15 @@ def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32):


@overload
-def empty(shape, dtype: str = _dtypes.float32): ...
+def empty(shape, dtype: str = _dtypes.float32) -> Tensor: ...


-def empty(*shape, dtype: str = _dtypes.float32):
+def empty(*shape, dtype: str = _dtypes.float32) -> Tensor:
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        return OutTensor(shape[0], dtype)
    elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
        return OutTensor(shape[0], shape[1])
    elif all([isinstance(x, (int, PrimExpr)) for x in shape]):
        return OutTensor(shape, dtype)
    else:
-        raise RuntimeError(f"Invalid shape {shape}")
+        raise TypeError(f"Invalid shape {shape}")
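Reviewer note: `empty` dispatches on three call shapes, and the new `-> Tensor` annotations now cover all of them. A short sketch of the accepted forms inside an eager-JIT body (values illustrative):

    B1 = T.empty((128, 128))                 # one tuple/list; dtype defaults to float32
    B2 = T.empty([128, 128], "float16")      # shape plus a positional dtype string
    B3 = T.empty(128, 128, dtype="float32")  # bare ints or PrimExprs as varargs
    T.empty("128", 128)                      # rejected; now a TypeError rather than RuntimeError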
11 changes: 11 additions & 0 deletions tilelang/language/eager/ast.py
@@ -570,9 +570,20 @@ def visit_Return(self, node: ast.Return):
        return quote("return __tb.ret(value)", value=node.value, span=node)

    def visit_With(self, node: ast.With):
+        is_kernel_ctx = False
+        for expr in node.items:
+            cexpr = expr.context_expr
+            if isinstance(cexpr, ast.Call) and isinstance(cexpr.func, ast.Attribute) and cexpr.func.attr == "Kernel":
+                eval_res = self._try_eval(cexpr.func)
+                from tilelang.language import Kernel
+
+                if eval_res is Kernel:
+                    is_kernel_ctx = True
        node = self.generic_visit(node)
        for expr in node.items:
            expr.context_expr = quote_expr("__tb.ctx_with(e)", e=expr.context_expr, span=expr)
+        if is_kernel_ctx:
+            return [quote1("if __tb.skip_kernel_ctx(): return"), node]
        return node

    def visit_Assert(self, node: ast.Assert):
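Reviewer note: the `visit_With` change means a `with T.Kernel(...)` block is now guarded so that phase-1 signature tracing returns before the kernel body ever executes. Expanded conceptually as source (not the literal quoted AST; `grid` and `body` are placeholders):

    # with T.Kernel(grid) as bx:
    #     body()
    # becomes, after the rewrite:
    if __tb.skip_kernel_ctx():  # True only during eager-JIT phase 1
        return
    with __tb.ctx_with(T.Kernel(grid)) as bx:
        body()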
97 changes: 65 additions & 32 deletions tilelang/language/eager/builder.py
@@ -164,6 +164,12 @@ def is_var(v: Any) -> bool:
    return isinstance(v, Buffer) and v.scope() == "local.var"


+# phase1: eager JIT obtains the function signature
+# phase2: eager JIT elaborates the function body
+# none: not inside eager JIT, i.e. lazy JIT
+EagerJITStage = Literal["phase1", "phase2", "none"]
+
+
class Builder(BaseBuilder):
    def __init__(self):
        self.frames: list[AnyFrame] = []
@@ -173,7 +179,8 @@ def __init__(self):
        self.out_idx = []
        self.out_tensor_cnt = 0
        self.constexpr_var = set()
-        self.eager_jit = False
+        self.eager_jit: EagerJITStage = "none"
+        self.eager_jit_subs: dict[str, PrimExpr] = {}
        self.current_file = "<unknown>"
        self.current_line = 0
        self.current_macro_name = "<unknown-macro>"
@@ -192,7 +199,7 @@ def prim_func(self, name):
        with self.ir_builder, self.with_frame(tir.prim_func()):
            tir.func_name(name)
            yield
-            if len(self.out_idx) != self.out_tensor_cnt:
+            if self.eager_jit != "phase1" and len(self.out_idx) != self.out_tensor_cnt:
                raise RuntimeError("Not all tensor allocated from `T.empty` are returned")
    finally:
        del thread_local_storage.builder
@@ -597,6 +604,8 @@ def ret(self, value=None):
            )
            return value
        else:
+            if self.eager_jit == "phase1":
+                return NotImplemented
            if not isinstance(value, tuple):
                value = (value,)
            for v in value:
@@ -683,6 +692,7 @@ def override(self, name: str):
    def constexpr(self, name: str, dtype: str = "int32") -> Var:
        var = tir.Var(name, dtype)
        self.constexpr_var.add(var)
+        var.orig_name = name
        return var

    def set_fileline(self, filename: str, lineno: int, name: str):
@@ -694,6 +704,9 @@ def get_fileline_stack(self, stacklevel=1):
        stack = self.macro_fileline_stack + [(self.current_file, self.current_line, self.current_macro_name)]
        return stack[: len(stack) - stacklevel + 1]

+    def skip_kernel_ctx(self):
+        return self.eager_jit == "phase1"
+

_P = ParamSpec("_P")
_T = TypeVar("_T")
@@ -866,17 +879,29 @@ def kernel(A, B):
    builder = Builder.current()
    # assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)"
    # assert builder.eager_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
-    if builder is None or not builder.eager_jit:
+    if builder is None or builder.eager_jit == "none":
        raise JITNoBuilderError("T.const() can only be used inside @tilelang.jit (eager mode)")

-    if "," in name:
-        names = re.split(r"\s*,\s*", name)
-        return tuple(builder.constexpr(n, dtype) for n in names)
-    if " " in name:
-        names = re.split(r"\s+", name)
-        return tuple(builder.constexpr(n, dtype) for n in names)
-    else:
-        return builder.constexpr(name, dtype)
+    if builder.eager_jit == "phase1":
+        # in phase 1, we create constexpr variables
+        if "," in name:
+            names = re.split(r"\s*,\s*", name)
+            return tuple(builder.constexpr(n, dtype) for n in names)
+        if " " in name:
+            names = re.split(r"\s+", name)
+            return tuple(builder.constexpr(n, dtype) for n in names)
+        else:
+            return builder.constexpr(name, dtype)
+    elif builder.eager_jit == "phase2":
+        # in phase 2, we substitute constexpr variables with their actual values
+        if "," in name:
+            names = re.split(r"\s*,\s*", name)
+            return tuple(builder.eager_jit_subs[n] for n in names)
+        if " " in name:
+            names = re.split(r"\s+", name)
+            return tuple(builder.eager_jit_subs[n] for n in names)
+        else:
+            return builder.eager_jit_subs[name]


@dataclass
@@ -889,12 +914,17 @@ class TirTemplate(Generic[_P, _T]):
    actual tensor shapes at runtime.
    """

+    name: str
    prim_func: PrimFunc[_P, _T]
    matcher: dict[Var, tuple[tvm.tir.Var, str, int, str]] | None = None
+    constexprs: set[Var] = None
    is_lazy_style: bool = False  # True if from lazy-style (returns PrimFunc directly)
+    ir_gen: IRGenerator[_P, _T] | None = None

    @classmethod
-    def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate[_P, _T]:
+    def create(
+        cls, name: str, prim_func: PrimFunc[_P, _T], constexpr: set[Var], ir_gen: IRGenerator[_P, _T] | None = None
+    ) -> TirTemplate[_P, _T]:
        matcher = {}
        for k, v in prim_func.buffer_map.items():
            for i, s in enumerate(v.shape):
@@ -915,12 +945,13 @@ def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate
f"Buffer shapes: {shapes}\n"
f"Buffer strides: {strides}"
)
return cls(prim_func=prim_func, matcher=matcher, is_lazy_style=False)
matcher = {k: matcher[k] for k in constexpr}
return cls(name=name, prim_func=prim_func, matcher=matcher, constexprs=constexpr, is_lazy_style=False, ir_gen=ir_gen)

@classmethod
def from_lazy_style(cls, prim_func: PrimFunc[_P, _T]) -> TirTemplate[_P, _T]:
def from_lazy_style(cls, name: str, prim_func: PrimFunc[_P, _T]) -> TirTemplate[_P, _T]:
"""Create template from lazy-style function that returns PrimFunc directly."""
return cls(prim_func=prim_func, is_lazy_style=True)
return cls(name=name, prim_func=prim_func, is_lazy_style=True)

def _parse_phase2_key(self, **kwargs):
if self.matcher is None:
@@ -946,16 +977,20 @@ def _parse_phase2_key(self, **kwargs):
            )
        return tuple(result)

-    def get_tir(self, **kwargs):
+    def get_tir(self, tensor_args, given_tensor_args, kwargs):
        if self.is_lazy_style:
            return self.prim_func
-        values = self._parse_phase2_key(**kwargs)
-        subs = {name: value for name, value in zip(self.matcher, values)}
-        result = substitute_primfunc(self.prim_func, subs)
-        result.orig_func = self.prim_func.orig_func
-        if hasattr(self.prim_func, "out_idx_override"):
-            result.out_idx_override = self.prim_func.out_idx_override
-        return result
+        values = self._parse_phase2_key(**given_tensor_args, **kwargs)
+        subs = {name.orig_name: value for name, value in zip(self.matcher, values)}
+        builder = Builder()
+        builder.eager_jit = "phase2"
+        builder.eager_jit_subs = subs
+        with builder.prim_func(self.name):
+            self.ir_gen.gen(builder)(**tensor_args, **kwargs)
+        pf = builder.get()
+        if builder.out_idx:
+            pf.out_idx_override = builder.out_idx
+        return pf


@dataclass
@@ -1021,7 +1056,7 @@ def foo(A, B):
            # lazy jit must return PrimFunc
            if isinstance(prim_func, PrimFunc):
                p1_key, _, _ = self._parse_phase1_key(*args, **kwargs)
-                self.p1_cache[p1_key] = TirTemplate.from_lazy_style(prim_func)
+                self.p1_cache[p1_key] = TirTemplate.from_lazy_style(self.orig_func.__name__, prim_func)
                return True
            return False
        except (JITNoBuilderError, EagerJITBuildError, TypeError):
@@ -1036,18 +1071,18 @@ def _build_tir_template(self, *args, **kwargs) -> TirTemplate[_P, _T]:
"""Build TIR template based on the execution mode."""
if self.mode == "lazy":
# lazy: function returns PrimFunc directly
return TirTemplate.from_lazy_style(self.orig_func(*args, **kwargs))
return TirTemplate.from_lazy_style(self.orig_func.__name__, self.orig_func(*args, **kwargs))
elif self.mode == "eager":
# eager: trace function body through Builder to construct TIR
builder = Builder()
builder.eager_jit = True
builder.eager_jit = "phase1"
with builder.prim_func(self.orig_func.__name__):
self.ir_gen.gen(builder)(**self.tensor_args, **kwargs)
pf = builder.get()
pf.orig_func = self.orig_func
if builder.out_idx:
pf.out_idx_override = builder.out_idx
return TirTemplate.create(pf, builder.constexpr_var)
return TirTemplate.create(self.orig_func.__name__, pf, builder.constexpr_var, self.ir_gen)
else:
raise ValueError(f"Invalid jit mode: {self.mode}, expected 'lazy' or 'eager'")

@@ -1061,17 +1096,15 @@ def parse_args(self, *args, **kwargs):
        # mode should be set by JITImpl before calling parse_args
        tir_temp = self._build_tir_template(**kwargs)
        self.p1_cache[p1_key] = tir_temp
-        p2_key = tir_temp._parse_phase2_key(**tensor_args)
+        p2_key = tir_temp._parse_phase2_key(**tensor_args, **kwargs)
        return (p1_key, p2_key), tensor_args

    def get_tir(self, *args, **kwargs):
        p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs)
        if p1_key not in self.p1_cache:
            # in legacy gemm, we use lazy tir template to build the tir
-            tir_temp = self._build_tir_template(**kwargs)
-            self.p1_cache[p1_key] = tir_temp
-            return tir_temp.get_tir(**tensor_args, **kwargs)
-        return self.p1_cache[p1_key].get_tir(**tensor_args, **kwargs)
+            self.p1_cache[p1_key] = self._build_tir_template(**kwargs)
+        return self.p1_cache[p1_key].get_tir(self.tensor_args, tensor_args, kwargs)

    def __call__(self, *args, **kwargs):
        return self.get_tir(*args, **kwargs)
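Reviewer note: taken together, the builder changes replace the old substitute-on-the-PrimFunc approach (`substitute_primfunc`) with a full re-elaboration of the function body. A condensed sketch of the resulting two-phase flow, using only names from this diff (argument plumbing simplified):

    # Phase 1: T.const creates symbolic Vars, kernel bodies are skipped via
    # skip_kernel_ctx(), and the traced PrimFunc plus its constexpr set form a template.
    builder = Builder()
    builder.eager_jit = "phase1"
    with builder.prim_func(name):
        ir_gen.gen(builder)(**tensor_args, **kwargs)
    template = TirTemplate.create(name, builder.get(), builder.constexpr_var, ir_gen)

    # Phase 2: the matcher binds each constexpr name to a concrete value taken from
    # the real tensor shapes, and the body is traced again with T.const returning
    # those values instead of fresh Vars.
    builder = Builder()
    builder.eager_jit = "phase2"
    builder.eager_jit_subs = {"N": 16}  # illustrative; computed by _parse_phase2_key
    with builder.prim_func(name):
        ir_gen.gen(builder)(**tensor_args, **kwargs)
    prim_func = builder.get()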