Skip to content

Commit ee3135f

Browse files
author
Siyuan Feng
committed
Support TVMScript meta-programming
1 parent effc23d commit ee3135f

File tree

3 files changed

+77
-6
lines changed

3 files changed

+77
-6
lines changed

python/tvm/script/context_maintainer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ class ContextMaintainer:
121121
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
122122
symbols: List[Dict[str, Union[Var, Buffer]]] = []
123123
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
124+
closure_vars: Dict[str, Object] = {}
125+
"""ClosureVars: The closure vars defined in Python interpreter"""
124126

125127
# function context
126128
func_params: List[Var] = []
@@ -144,12 +146,17 @@ class ContextMaintainer:
144146
root_alloc_buffers: List[Buffer] = []
145147
"""List[Buffer]: The buffers allocated under root block"""
146148

147-
def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
149+
def __init__(
150+
self,
151+
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
152+
closure_vars: Dict[str, Object],
153+
):
148154
# scope context
149155
self.node_stack = []
150156
self.block_info_stack = []
151157
self.loop_stack = {}
152158
self.symbols = []
159+
self.closure_vars = closure_vars
153160
# function context
154161
self.func_params = []
155162
self.func_buffer_map = {}
@@ -233,7 +240,7 @@ def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
233240
for symbols in reversed(self.symbols):
234241
if name in symbols:
235242
return symbols[name]
236-
return None
243+
return self.closure_vars.get(name)
237244

238245
def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
239246
self._report_error(message, span)

python/tvm/script/parser.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,18 +158,21 @@ class TVMScriptParser(Transformer):
158158

159159
# pylint gets confused here with synr.Transformer which doesn't have a
160160
# custom init, so just disable it
161-
def __init__(self, base_lineno, tir_namespace): # pylint: disable=super-init-not-called
161+
def __init__(
162+
self, base_lineno, tir_namespace, closure_vars
163+
): # pylint: disable=super-init-not-called
162164
self.context = None
163165

164166
self.base_lineno = base_lineno
165167
self.current_lineno = 0
166168
self.current_col_offset = 0
167169
self.tir_namespace = tir_namespace
170+
self.closure_vars = closure_vars
168171
self.meta = None
169172

170173
def init_function_parsing_env(self):
171174
"""Initialize function parsing environment"""
172-
self.context = ContextMaintainer(self.report_error) # scope emitter
175+
self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter
173176

174177
def init_meta(self, meta_dict):
175178
if meta_dict is not None:
@@ -1252,12 +1255,14 @@ def from_source(
12521255
"""
12531256
if isinstance(input_func, str):
12541257
tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
1255-
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix))
1258+
return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
12561259
elif inspect.isfunction(input_func):
12571260
_, start_line = inspect.getsourcelines(input_func)
12581261
env: Dict[str, Any] = input_func.__globals__
12591262
namespace = [key for key in env.keys() if env[key] is tir]
1260-
parser = TVMScriptParser(start_line, namespace)
1263+
_closure_vars = inspect.getclosurevars(input_func)
1264+
closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals}
1265+
parser = TVMScriptParser(start_line, namespace, closure_vars)
12611266
result = to_ast(input_func, TVMDiagnosticCtx(), parser)
12621267
return result
12631268
else:
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import tvm
19+
from tvm.script import tir as T
20+
21+
22+
def matmul_generator(M: int, N: int, K: int, dtype: str):
23+
@T.prim_func
24+
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
25+
A = T.match_buffer(a, [M, K], dtype=dtype)
26+
B = T.match_buffer(b, [N, K], dtype=dtype)
27+
C = T.match_buffer(c, [M, N], dtype=dtype)
28+
29+
for i, j, k in T.grid(M, N, K):
30+
with T.block():
31+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
32+
with T.init():
33+
C[vi, vj] = T.float32(0)
34+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
35+
36+
return matmul
37+
38+
39+
@T.prim_func
40+
def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None:
41+
A = T.match_buffer(a, [128, 128], dtype="float16")
42+
B = T.match_buffer(b, [128, 128], dtype="float16")
43+
C = T.match_buffer(c, [128, 128], dtype="float16")
44+
45+
for i, j, k in T.grid(128, 128, 128):
46+
with T.block():
47+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
48+
with T.init():
49+
C[vi, vj] = T.float32(0)
50+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
51+
52+
53+
def test_meta_programming_matmul():
54+
f = matmul_generator(128, 128, 128, "float16")
55+
tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16)
56+
57+
58+
if __name__ == "__main__":
59+
test_meta_programming_matmul()

0 commit comments

Comments
 (0)