Skip to content

Commit bf90a5f

Browse files
[Fix] Fix frame scope error in T.macro (#1308)
* [Fix] Fix #1307 by adding macro inside function * fix lint error * add comments and fix lint error * Remove debug print from enter_frame method Removed debug print statement from enter_frame method. --------- Co-authored-by: Lei Wang <[email protected]>
1 parent 17bbc0c commit bf90a5f

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

testing/python/language/test_tilelang_language_frontend_v2.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,5 +427,31 @@ def prim_call_macro():
427427
pass
428428

429429

430+
def frame_inside_macro():
431+
432+
@tilelang.jit
433+
def get_sample_kernel():
434+
435+
@T.macro
436+
def transform(x):
437+
return x + 1
438+
439+
@T.prim_func
440+
def sample_kernel(
441+
num_blocks: T.int32,
442+
idx_out: T.Tensor[(32,), T.int32],
443+
):
444+
with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841
445+
fragment = T.alloc_fragment(32, 'int32')
446+
T.copy(idx_out, fragment)
447+
448+
for i in T.Parallel(32):
449+
idx_out[i] = transform(fragment[i])
450+
451+
return sample_kernel
452+
453+
kernel = get_sample_kernel() # noqa: F841
454+
455+
430456
if __name__ == '__main__':
431457
tilelang.testing.main()

tilelang/language/v2/builder.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class MacroFrame(Frame):
8080
...
8181

8282

83+
class ExitedMacroFrame(Frame):
84+
...
85+
86+
8387
class BoolOpFrame(Frame):
8488
...
8589

@@ -164,8 +168,22 @@ def macro(self, name=None, annotations=None):
164168
save = self.name_inside_frame, self.arg_annotations
165169
self.name_inside_frame = {}
166170
self.arg_annotations = annotations or {}
167-
with self.with_frame(MacroFrame()):
168-
yield
171+
pos = len(self.frames)
172+
# here we add a ExitedMacroFrame to preserve the frame stack inside macro
173+
# because macro may bind some variable, and return it
174+
#
175+
# ```py
176+
# @T.macro
177+
# def foo(x):
178+
# y = x + 1
179+
# return y
180+
# @T.prim_func
181+
# def bar():
182+
# c = foo(1) # macro generates let y = x + 1
183+
# d = c # d = c should lay inside frame of `let y = x + 1`
184+
self.frames.append(MacroFrame())
185+
yield
186+
self.frames[pos] = ExitedMacroFrame()
169187
self.name_inside_frame, self.arg_annotations = save
170188

171189
def get(self):

0 commit comments

Comments
 (0)