Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/script/parser/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# pylint: disable=unused-import
from .core import dispatch, doc, utils
from .core.dispatch import OpMethod, register_op
from .core.entry import parse, parse_macro
from .core.entry import parse, scan_macro
from .core.parser import Parser
6 changes: 2 additions & 4 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@ def _default_globals() -> Dict[str, Any]:
return extra_vars


def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
"""Generate the AST, and the source code for __repr__."""
# The AST will be converted into TIR at the time of expansion.
source = Source(program)
source_txt = source.source
source_ast = source.as_ast()
closure_vars = extra_vars or _default_globals()
return source_ast, source_txt, closure_vars
return source, closure_vars


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
Expand Down
105 changes: 105 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""The core parser"""

import abc
import inspect
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Union
Expand Down Expand Up @@ -65,6 +67,108 @@ def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument
pass


class ScriptMacro(abc.ABC):
"""Representation of a script macro.

This is a callable object, intended to be called from the expression evaluator.
The evaluator is expected to insert the current parser into the environment
undef the name given by "parser_object_name".

Once called, the ScriptMacro object will locate the current parser, and use it
to parse the macro's body and produce the result.

There were two major considerations for this design:
1. Implementing hygienic and non-hygienic macros.
2. Implementing macros that return values.

Macro uses in TIR are only allowed at a statement-level, and they don't produce
any values. Parsing of such macros could easily be done by intercepting doc.Call
nodes in the TIR parser. If a macro is a value-producing expression, then there
may not be a direct way to intercept calls to it if it's embedded in a complex
expression. Because macros use function-call syntax, the evaluator will try to
call the macro object, which this design relies on to parse and evaluate the macro.
"""

parser_object_name = "__current_script_parser__"

def __init__(
self,
source: Source,
closure_vars: Dict[str, Any],
func: Callable,
hygienic: bool,
) -> None:
self.source = source
self.closure_vars = closure_vars
self.func = func
self.hygienic = hygienic

def __repr__(self):
return self.source.source

@abc.abstractmethod
def parse_macro(self, parser: "Parser") -> Any:
"""The main macro parsing function. Different scripts may have different
ways to parse a macro, and to return a value to the evaluator.

Parameters
----------
parser : Parser
The parser with the appropriate frame already created and populated depending
macro's hygiene settings,

Returns
-------
The return value depends on the specifics of the particular script. It can be
"None" or any other value or any type.
"""

def _find_parser_def(self):
outer_frame_infos = inspect.getouterframes(inspect.currentframe())
for finfo in outer_frame_infos:
parser = finfo.frame.f_globals.get(ScriptMacro.parser_object_name)
if parser is not None:
return parser
raise RuntimeError(f"{ScriptMacro.parser_object_name} not available")

def get_macro_def(self):
ast_module = self.source.as_ast()
for decl in ast_module.body:
if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__:
return decl
raise RuntimeError(f"cannot find macro definition for {self.__name__}")

def __call__(self, *args, **kwargs):
param_binding = inspect.signature(self.func).bind(*args, **kwargs)
param_binding.apply_defaults()
local_vars = param_binding.arguments
parser = self._find_parser_def()

if self.hygienic:
saved_var_table = parser.var_table
parser.var_table = VarTable()

with parser.var_table.with_frame():
for k, v in self.closure_vars.items():
parser.var_table.add(k, v)
for k, v in local_vars.items():
parser.var_table.add(k, v)

parse_result = self.parse_macro(parser)

parser.var_table = saved_var_table

else:
with parser.var_table.with_frame():
for k, v in local_vars.items():
parser.var_table.add(k, v)

print(parser.var_table.get())
parse_result = self.parse_macro(parser)

return parse_result


class VarTableFrame:
"""The variable table frame.
A frame of variable table stores the variables created in one block or scope.
Expand Down Expand Up @@ -326,6 +430,7 @@ def eval_expr(
if extra_vars is not None:
for k, v in extra_vars.items():
var_values[k] = v
var_values[ScriptMacro.parser_object_name] = self
return eval_expr(self, node, var_values)

def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
Expand Down
41 changes: 10 additions & 31 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.
"""The entry point of TVM parser for tir."""
import inspect
from typing import Any, Callable, Dict, Optional, Union
from typing import Callable, Optional, Union

from tvm.ir.base import deprecated
from tvm.tir import Buffer, PrimFunc

from ...ir_builder.tir import buffer, ptr
from .._core import doc, parse, parse_macro, utils
from .._core import parse, scan_macro, utils
from ..core.parser import Parser, ScriptMacro


def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]:
Expand Down Expand Up @@ -86,25 +87,12 @@ def decorator_wrapper(func):
# inserted at the point where the call to the macro is located.


class TIRMacro:
"""Representation of T.macro."""
class TIRMacro(ScriptMacro):
"""Specialization of the ScriptMacro class for TIR."""

def __init__(
self,
source_ast: doc.AST,
source_txt: str,
closure_vars: Dict[str, Any],
func: Callable,
hygienic: bool,
) -> None:
self.source_ast = source_ast
self.source_txt = source_txt
self.closure_vars = closure_vars
self.func = func
self.hygienic = hygienic

def __repr__(self):
return self.source_txt
def parse_macro(self, parser: Parser) -> None:
macro_def = self.get_macro_def()
parser.visit_body(macro_def.body)


def macro(*args, hygienic: bool = True) -> Callable:
Expand Down Expand Up @@ -147,15 +135,9 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
"""

def _decorator(func: Callable) -> TIRMacro:
source_ast, source_txt, closure_vars = parse_macro(
func, utils.inspect_function_capture(func)
)
obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic)
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = TIRMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
# We don't need to explicitly store the return value anywhere.
# This function is a decorator, so the return value will replace
# the function definition (to which the decorator it is applied)
# in that function's name space.
return obj

if len(args) == 0:
Expand All @@ -168,9 +150,6 @@ def _decorator(func: Callable) -> TIRMacro:
)


# There is no dispatch_token for macro, because macro doesn't invoke parser.


class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""

Expand Down
59 changes: 1 addition & 58 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
"""The base parser for tir"""

import contextlib
import inspect
from functools import partial
from typing import Any, Union
from typing import Any

import tvm
from tvm.ir import GlobalVar, PrimType
Expand All @@ -30,8 +29,6 @@
from ...ir_builder.base import IRBuilder
from ...ir_builder.base import IRBuilderFrame as Frame
from .._core import Parser, dispatch, doc
from ..core.parser import VarTable
from .entry import TIRMacro


def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
Expand Down Expand Up @@ -447,11 +444,6 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
The doc AST Expr node.
"""

if isinstance(node.value, doc.Call):
callee = self.eval_expr(node.value.func)
if isinstance(callee, TIRMacro):
return expand_macro(self, callee, node.value)

res = self.eval_expr(node.value)
if res is None:
pass
Expand All @@ -472,7 +464,6 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
pass
else:
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")
return None # For pylint


@dispatch.register(token="tir", type_name="If")
Expand Down Expand Up @@ -554,51 +545,3 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
return I.decl_function(node.name, func_signature)


def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None:
"""Bind arguments to the macro invocation to the parameters in the macro definition,
and pass the macro body for further parsing.
"""

assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}"

def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]:
for decl in decl_list:
if isinstance(decl, doc.FunctionDef) and decl.name == name:
return decl
return None

macro_def = find_macro_def(callee.__name__, callee.source_ast.body)
assert macro_def is not None, f"Invalid macro AST for {callee.__name__}"
# `macro_def` is the FunctionDef of the macro.

args = [self.eval_expr(arg) for arg in call.args]
kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords}
param_binding = inspect.signature(callee.func).bind(*args, **kwargs)
param_binding.apply_defaults()
local_vars = param_binding.arguments

if callee.hygienic:
# If the macro was hygienic, construct new var_table with a single frame that
# contains the captured environment, and process the macro's body with that
# frame.
saved_var_table = self.var_table
self.var_table = VarTable()
with self.var_table.with_frame():
for k, v in callee.closure_vars.items():
self.var_table.add(k, v)
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)

self.var_table = saved_var_table

else:
# Otherwise, dynamically resolve symbols in the macro's body.
with self.var_table.with_frame():
for k, v in local_vars.items():
self.var_table.add(k, v)

self.visit_body(macro_def.body)