Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions testing/python/language/test_tilelang_language_frontend_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,5 +361,39 @@ def test_while_loop(A: T.Tensor((1,), T.int32)):
assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"


def test_var_macro():
try:

@T.macro
def macro_with_var(x: T.Var):
x = 1 # noqa: F841

@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = T.alloc_var(T.int32)
macro_with_var(x)

assert 'x[0] = 1' in prim_call_macro.script()
finally:
pass

try:

@T.macro
def macro_with_var(x: T.Var):
x = 1 # noqa: F841

@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = 1
macro_with_var(x)

raise RuntimeError("Expect to report an error, x should not be passed as T.Var")
except ValueError:
pass


if __name__ == '__main__':
tilelang.testing.main()
63 changes: 33 additions & 30 deletions tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(self):
self.frames: list[AnyFrame] = []
self.ir_builder = IRBuilder()
self.name_inside_frame: dict[str, AnyFrame] = {}
self.arg_annotations = {}

@classmethod
def current(cls) -> Self:
Expand All @@ -155,16 +156,17 @@ def prim_func(self, name):
yield

@contextmanager
def macro(self, name=None):
def macro(self, name=None, annotations=None):
if self.find_frame_idx(BoolOpFrame) is not None:
raise RuntimeError(
f"Macro `{name}` is used inside boolean expressions, "
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
save = self.name_inside_frame
save = self.name_inside_frame, self.arg_annotations
self.name_inside_frame = {}
self.arg_annotations = annotations or {}
with self.with_frame(MacroFrame()):
yield
self.name_inside_frame = save
self.name_inside_frame, self.arg_annotations = save

def get(self):
return self.ir_builder.get()
Expand Down Expand Up @@ -313,32 +315,18 @@ def bind(self, name, value, annot=BaseBuilder.empty):
self.check_continue_break()
locals = self.get_parent_locals()
orig_value = locals.get(name, None)
# annotation like tl.float32
# temporarily disable annotation based var declaration, for better pull request separation
# if callable(annot):
# annot_val = annot()
# if isinstance(annot_val, tir.Var):
# orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var')
# IRBuilder.name(name, orig_value)
# if isinstance(value, EllipsisType) or value is self.empty:
# return orig_value
# elif isinstance(value, (int, float, IntImm, FloatImm)):
# tir.block_attr(
# {'tl.local_var_init': {
# orig_value.data: tvm.runtime.convert(value)
# }})
# return orig_value
# if orig_value is a local.var, we use buffer_store to modify it immutably
# however, if rvalue is also a local.var, this is a new binding,
# however, if rvalue is not a PrimExpr, such as buffer,
# we should not use buffer_store, and bind it instead
# ```py
# a = tl.alloc_var('float32') # bind var `a`
# a = tl.alloc_var('float32') # bind a new var `a_1`
# a = tl.alloc_shared((1,), T.float32) # bind a to new buffer
# b = a # get value of var `b = a_1[0]``
# c = tl.alloc_var('float32') # bind var `c`
# c = a # get and assign `c[0] = a_1[0]`
# ```
if is_var(orig_value) and not is_var(value):
if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)):
tir.buffer_store(orig_value, value, 0)
return orig_value
res = self.bind_immutable(name, value)
Expand Down Expand Up @@ -486,22 +474,34 @@ def rval(self, name: str, value: Any) -> Any:
)
return self.unwrap_value(value)

def arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None:
if isinstance(value, (PrimExpr, int, float)):
return self.bind(name, value)
else:
return value
def macro_arg(self, name, value):
if self.arg_annotations.get(name, None) is Var:
is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var'
if not is_var:
raise ValueError(
f'Argument `{name}` is expected to be a variable allocated by `T.alloc_var`, but got {value}({type(value)})'
)
return value.buffer
elif isinstance(value, (PrimExpr, int, float)):
return self.bind(name, value)
else:
return value

def prim_func_arg(self, name, value):
if isinstance(value, (Buffer, Var)):
return tir.arg(name, value)
elif value is self.empty:
raise ValueError(f'Argument `{name}` is not annotated')
# elif isinstance(value, Hashable):
# return value
else:
raise TypeError(
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")

def arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None:
return self.macro_arg(name, value)
else:
return self.prim_func_arg(name, value)

def override(self, name: str):
from tilelang.language import serial
if name == 'range':
Expand Down Expand Up @@ -533,14 +533,15 @@ class Macro(Generic[_P, _T]):
name: str
orig_func: Callable[_P, _T]
ir_gen: IRGenerator[_P, _T]
annotations: dict[str, Any]

@property
def source(self) -> str:
return self.ir_gen.source

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
builder = Builder.current()
with builder.macro(self.name):
with builder.macro(self.name, self.annotations):
res = self.ir_gen.gen(builder)(*args, **kwargs)
return res

Expand Down Expand Up @@ -578,7 +579,9 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
"""

def impl(func: Callable[_P, _T]) -> Macro[_P, _T]:
return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func))
annotations = get_type_hints(func)
return Macro(
name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations)

return impl(func) if func is not None else impl

Expand Down
Loading