From 69b08fb6e7fae8387161466ec53e2828f10a8cec Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Jul 2025 12:11:36 -0700 Subject: [PATCH 01/31] Rename converter Signed-off-by: Justin Chu --- onnxscript/{converter.py => _converter.py} | 17 ++++++----------- .../{converter_test.py => _converter_test.py} | 6 +++--- onnxscript/_internal/autocast.py | 10 +++++----- onnxscript/main.py | 4 ++-- onnxscript/values.py | 5 ++--- 5 files changed, 18 insertions(+), 24 deletions(-) rename onnxscript/{converter.py => _converter.py} (99%) rename onnxscript/{converter_test.py => _converter_test.py} (99%) diff --git a/onnxscript/converter.py b/onnxscript/_converter.py similarity index 99% rename from onnxscript/converter.py rename to onnxscript/_converter.py index dfcddefbd3..af5faa4ccc 100644 --- a/onnxscript/converter.py +++ b/onnxscript/_converter.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +"""Python-to-IR converter""" from __future__ import annotations import ast @@ -17,25 +18,23 @@ ) import onnx +import onnx_ir as ir import onnxscript from onnxscript import irbuilder, onnx_types, sourceinfo, values from onnxscript import type_annotation as ta from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation -logger = logging.getLogger("onnxscript") +logger = logging.getLogger(__name__) -# Python-to-IR converter: - def not_allowed(construct): return f"{construct}not supported." -class TranslationError(Exception): - def __init__(self, *args: object) -> None: - super().__init__(*args) +class TranslationError(RuntimeError): + pass def warn(msg): @@ -139,10 +138,6 @@ def __str__(self) -> str: class Converter: """Main class to translate python code into ONNX operators. - Args: - ir_builder: convert AST node into ONNX structures, if None, - class :class:`onnxscript.irbuilder.IRBuilder` is used - The class uses logger `onnxscript`. Logging can be enabled with the following code: :: @@ -169,7 +164,7 @@ def __init__( source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): - self.ir_builder = ir_builder or irbuilder.IRBuilder() + self._model = ir. self.source = source if global_names is not None: # We make a copy in case function eval modifies it. diff --git a/onnxscript/converter_test.py b/onnxscript/_converter_test.py similarity index 99% rename from onnxscript/converter_test.py rename to onnxscript/_converter_test.py index 9a7ca504a7..ff8aaca591 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/_converter_test.py @@ -23,7 +23,7 @@ import onnxscript import onnxscript.testing -from onnxscript import BOOL, FLOAT, INT64, converter, graph, script, tensor +from onnxscript import BOOL, FLOAT, INT64, _converter, graph, script, tensor from onnxscript.onnx_opset import opset11 as op11 from onnxscript.onnx_opset import opset15 as op from tests.common import onnx_script_test_case, testutils @@ -437,12 +437,12 @@ def check_failure(self, f, msg): global_names = globals().copy() top_level_ast = ast.parse(source) f_ast = top_level_ast.body[0] - cvt = converter.Converter( + cvt = _converter.Converter( opset=op, global_names=global_names, source=source, default_opset=op ) try: cvt.translate_function_def(f_ast) - except converter.TranslationError as e: + except _converter.TranslationError as e: if msg not in str(e): raise AssertionError(f"Unable to find {msg!r} in {e!r} in\n{source}") from e return diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 1defac3e53..d9ad48af35 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -13,7 +13,7 @@ from onnxscript import ir, tensor if TYPE_CHECKING: - from onnxscript import converter + from onnxscript import _converter # Conversions from python values to ONNX are used by both the script converter as well # as the eager-mode runtime and both need to be consistent. The script converter converts @@ -187,16 +187,16 @@ def get_type_info(x): def static_cast_inputs( - converter_: converter.Converter, + converter_: _converter.Converter, op_schema: Optional[OpSchema], - args: Sequence[Optional[converter.Variable]], + args: Sequence[Optional[_converter.Variable]], ) -> tuple[str, ...]: """Used for autocast during script-translation. This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))" Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed. """ - def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]: + def get_type_info(x: Optional[_converter.Variable]) -> Optional[_converter.Variable]: """Returns x back if x can serve as the target-type for a cast (as the second argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is castable, while X can serve as the target-type. @@ -204,7 +204,7 @@ def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variabl return None if x is None or x.is_castable else x def cast_like( - x: Optional[converter.Variable], y: Optional[converter.Variable] + x: Optional[_converter.Variable], y: Optional[_converter.Variable] ) -> Optional[str]: if x is None: return None diff --git a/onnxscript/main.py b/onnxscript/main.py index 3ea3e50f90..15d8247530 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -11,7 +11,7 @@ from typing_extensions import ParamSpec import onnxscript -from onnxscript import converter, ir, irbuilder, values +from onnxscript import _converter, ir, irbuilder, values from onnxscript._internal import ast_utils _R = TypeVar("_R") @@ -29,7 +29,7 @@ def script_check( # See if conversion succeeds. # TODO: cleanup Converter interface/API, separating checker from # converter - convert = converter.Converter( + convert = _converter.Converter( opset=opset, global_names=global_names, source=source, diff --git a/onnxscript/values.py b/onnxscript/values.py index 1897ae14d5..d957145281 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -25,8 +25,7 @@ import onnx.defs from typing_extensions import ParamSpec -from onnxscript import converter as converter_module -from onnxscript import irbuilder, sourceinfo, type_annotation +from onnxscript import _converter, irbuilder, sourceinfo, type_annotation from onnxscript._internal import ast_utils, deprecation from onnxscript.ir import _schemas @@ -638,7 +637,7 @@ def function_ir(self) -> irbuilder.IRFunction: closure = inspect.getclosurevars(self.func) global_names = module.__dict__.copy() global_names.update(closure.nonlocals) - converter = converter_module.Converter( + converter = _converter.Converter( opset=self._opset, global_names=global_names, source=src, From 7e0a7673f9708faa5dc39b830d3982c414506fda Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Jul 2025 12:15:24 -0700 Subject: [PATCH 02/31] Update fields Signed-off-by: Justin Chu --- onnxscript/_converter.py | 124 ++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 60 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index af5faa4ccc..065ba9a266 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -25,6 +25,38 @@ from onnxscript import type_annotation as ta from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation +if TYPE_CHECKING: + # The type-alias LocalSymValue represents the types of values that local names in a + # script-function may be bound to during translation, (ONNX IR values). + # TODO(rama): Rationalize this and values.SymbolValue + + LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction] + + # The type-alias PyValue is used to represent the types of python values that may be used + # in an ONNX Script function. + # TODO(rama): Flesh out the set of valid types here. These include values such as + # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for + # use as value-parameters or attribute-parameters in an ONNX call (Node). + + PyValue = Any + + # The type-alias SymValue denotes values that an identifier may be bound to during + # translation. A local name will be bound to a LocalSymValue, while a global name + # will be bound to a PyValue. + + SymValue = Union[LocalSymValue, PyValue] + + # PreferredName is a type-alias used to represent the preferred name used in the generated + # ONNX for a value returned by an expression. There is no guarantee that the specified + # name will be used exactly. The converter will modify the name (with a suffix), + # if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement). + + PreferredName = str + + # The type-alias OnnxVar indicates variable names used in the generated ONNX. + OnnxVarName = str + + logger = logging.getLogger(__name__) @@ -56,7 +88,7 @@ def ignore(cond, msg): # map from python operators to ONNX ops -primop_map = { +_PRIMOP_MAP = { ast.Add: "Add", ast.And: "And", ast.BitAnd: "And", @@ -103,36 +135,7 @@ def __str__(self) -> str: return self.name -if TYPE_CHECKING: - # The type-alias LocalSymValue represents the types of values that local names in a - # script-function may be bound to during translation, (ONNX IR values). - # TODO(rama): Rationalize this and values.SymbolValue - - LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction] - # The type-alias PyValue is used to represent the types of python values that may be used - # in an ONNX Script function. - # TODO(rama): Flesh out the set of valid types here. These include values such as - # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for - # use as value-parameters or attribute-parameters in an ONNX call (Node). - - PyValue = Any - - # The type-alias SymValue denotes values that an identifier may be bound to during - # translation. A local name will be bound to a LocalSymValue, while a global name - # will be bound to a PyValue. - - SymValue = Union[LocalSymValue, PyValue] - - # PreferredName is a type-alias used to represent the preferred name used in the generated - # ONNX for a value returned by an expression. There is no guarantee that the specified - # name will be used exactly. The converter will modify the name (with a suffix), - # if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement). - - PreferredName = str - - # The type-alias OnnxVar indicates variable names used in the generated ONNX. - OnnxVarName = str class Converter: @@ -158,19 +161,20 @@ class Converter: def __init__( self, - ir_builder: Optional[irbuilder.IRBuilder] = None, opset: Optional[values.Opset] = None, global_names: Optional[dict[str, Any]] = None, source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): - self._model = ir. - self.source = source + self._source = source if global_names is not None: # We make a copy in case function eval modifies it. - self.globals = global_names.copy() - self.this_module = opset - self.default_opset_ = default_opset + self._globals = global_names.copy() + self._this_module = opset + self._default_opset = default_opset + + # TODO(justinchuby): Update ir version to be user defined + self._model = ir.Model(ir.Graph((), (), nodes=()), ir_version=10) # States initialized by `_init_function_translation` self._outer: List[irbuilder.IRFunction] = [] @@ -181,26 +185,26 @@ def __init__( @property def default_opset(self) -> values.Opset: - if self.default_opset_ is None: + if self._default_opset is None: raise RuntimeError( "default_opset must be specified in script for functions " "that do not contain any use of an ONNX opset." ) - return self.default_opset_ + return self._default_opset def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None: if opset.domain != "": return - if self.default_opset_ is not None: + if self._default_opset is not None: if ( - opset.domain != self.default_opset_.domain - or opset.version != self.default_opset_.version + opset.domain != self._default_opset.domain + or opset.version != self._default_opset.version ): self.fail( - node, f"Two distincts opset were used ({opset} != {self.default_opset_})." + node, f"Two distincts opset were used ({opset} != {self._default_opset})." ) else: - self.default_opset_ = opset + self._default_opset = opset def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: """Find the (first) ONNX opset used in the function, if any.""" @@ -209,8 +213,8 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: if isinstance(node.func, ast.Attribute): opset_expr = node.func.value if isinstance(opset_expr, ast.Name): - if opset_expr.id in self.globals: - opset = self.globals[opset_expr.id] + if opset_expr.id in self._globals: + opset = self._globals[opset_expr.id] if isinstance(opset, values.Opset) and opset.domain == "": return opset for child in ast.iter_child_nodes(node): @@ -228,7 +232,7 @@ def _init_function_translation(self) -> None: self._locals: List[Dict[str, LocalSymValue]] = [{}] def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: - return sourceinfo.SourceInfo(node, self.source, self._current_fn.name) + return sourceinfo.SourceInfo(node, self._source, self._current_fn.name) def _message(self, node: ast.AST, error_msg: str) -> str: """Constructs an error _message containing source information about an ast node.""" @@ -277,8 +281,8 @@ def _lookup( for scope in self._locals: if name in scope: return scope[name] - if name in self.globals: - return self.globals[name] + if name in self._globals: + return self._globals[name] if raise_exception: raise ValueError(info.msg(f"Unbound name: {name}.")) return None @@ -452,12 +456,12 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue: expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) cpl = compile(expr, filename="", mode="eval") try: - return eval(cpl, self.globals, locals) # pylint: disable=eval-used + return eval(cpl, self._globals, locals) # pylint: disable=eval-used except NameError as e: raise NameError( self._message( expr, - f"Missing names, globals contains {list(self.globals)!r}, " + f"Missing names, globals contains {list(self._globals)!r}, " f"locals {list(locals)!r}.", ) ) from e @@ -838,7 +842,7 @@ def _cast_like_binary_expression(self, op, left, right): def _translate_binary_op_expr(self, node: ast.BinOp): op = type(node.op) - if op not in primop_map: + if op not in _PRIMOP_MAP: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) attr = [] @@ -849,7 +853,7 @@ def _translate_binary_op_expr(self, node: ast.BinOp): if isinstance(cst, float): attr = [self._make_onnx_attr("fmod", 1)] - op = values.Op(self.default_opset, primop_map[op]) + op = values.Op(self.default_opset, _PRIMOP_MAP[op]) left, right = self._cast_like_binary_expression( op, self._translate_expr(node.left), self._translate_expr(node.right) ) @@ -857,7 +861,7 @@ def _translate_binary_op_expr(self, node: ast.BinOp): def _translate_unary_op_expr(self, node): op = type(node.op) - if op not in primop_map: + if op not in _PRIMOP_MAP: raise ValueError(self._message(node, self).msg(f"Unsupported operator {op!r}.")) if self._is_constant_expr(node.operand): # This function changed the constant node.operand @@ -878,7 +882,7 @@ def _translate_unary_op_expr(self, node): return self._translate_expr(cst) if op == ast.UAdd: return self._translate_expr(node.operand) - opname = primop_map[op] + opname = _PRIMOP_MAP[op] operand = self._translate_expr(node.operand) return values.Op(self.default_opset, opname), [operand], [] @@ -887,9 +891,9 @@ def _translate_compare_expr(self, node): assert len(node.ops) == 1 assert len(node.comparators) == 1 op = type(node.ops[0]) - if op not in primop_map: + if op not in _PRIMOP_MAP: raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) - opname = primop_map[op] + opname = _PRIMOP_MAP[op] left = self._translate_expr(node.left) right = self._translate_expr(node.comparators[0]) @@ -1437,21 +1441,21 @@ def _translate_function_def_common(self, fn: ast.FunctionDef) -> irbuilder.IRFun def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: if isinstance(stmt, ast.FunctionDef): self._init_function_translation() - if self.default_opset_ is None: + if self._default_opset is None: opset = self._find_onnx_opset(stmt) if opset: self._set_default_opset(opset, stmt) - domain = self.this_module.domain + domain = self._this_module.domain self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) analysis.do_liveness_analysis(stmt, self._message) fn_ir = self._translate_function_def_common(stmt) fn_ir.debug_print() - self.this_module.add_function_def(fn_ir) + self._this_module.add_function_def(fn_ir) return fn_ir raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.") def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: """Translate a (top-level) function signature.""" - domain = self.this_module.domain + domain = self._this_module.domain self._current_fn = self.ir_builder.new_function(fn.name, domain, True) return self._translate_function_signature_common(fn) From 8c7abc629421be5f9b7f1dac4c02ad86179f001d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Jul 2025 12:21:46 -0700 Subject: [PATCH 03/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 065ba9a266..79b773c73d 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """Python-to-IR converter""" + from __future__ import annotations import ast @@ -60,7 +61,6 @@ logger = logging.getLogger(__name__) - def not_allowed(construct): return f"{construct}not supported." @@ -135,9 +135,6 @@ def __str__(self) -> str: return self.name - - - class Converter: """Main class to translate python code into ONNX operators. @@ -176,12 +173,12 @@ def __init__( # TODO(justinchuby): Update ir version to be user defined self._model = ir.Model(ir.Graph((), (), nodes=()), ir_version=10) - # States initialized by `_init_function_translation` - self._outer: List[irbuilder.IRFunction] = [] - self._current_fn: irbuilder.IRFunction = None + # A stack of functions in the outer scope + self._outer: list[ir.Function] = [] + self._current_fn: ir.Function | None = None self._nextvar: int = 0 self._used_vars: set[str] = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._locals: list[dict[str, LocalSymValue]] = [{}] @property def default_opset(self) -> values.Opset: @@ -226,12 +223,13 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: def _init_function_translation(self) -> None: """Initialize self for translating a new (top-level) function.""" self._outer = [] - self._current_fn: Optional[irbuilder.IRFunction] = None + self._current_fn = None self._nextvar = 0 self._used_vars = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: + assert self._current_fn is not None return sourceinfo.SourceInfo(node, self._source, self._current_fn.name) def _message(self, node: ast.AST, error_msg: str) -> str: @@ -255,8 +253,15 @@ def _enter_scope(self, name: str, parent_node: ast.AST): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. """ - self._outer.insert(0, self._current_fn) - self._current_fn = self.ir_builder.new_function(name) + assert self._current_fn is not None + self._outer.append(self._current_fn) + assert self._this_module is not None + self._current_fn = ir.Function( + domain=self._this_module.domain, + name=name, + graph=ir.Graph((), (), nodes=[]), + attributes={}, + ) self._locals.insert(0, {}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) @@ -264,7 +269,7 @@ def _exit_scope(self) -> irbuilder.IRFunction: """Exit from a control-flow block (a loop body or if-then-else branch).""" logger.debug("Converter:_exit_scope:%d", len(self._locals)) graph = self._current_fn - self._current_fn = self._outer.pop(0) + self._current_fn = self._outer.pop() self._locals.pop(0) return graph From 370a9f58a831eda55ffdf146d133d1585715b0b2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Jul 2025 12:23:15 -0700 Subject: [PATCH 04/31] make locals a stack Signed-off-by: Justin Chu --- onnxscript/_converter.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 79b773c73d..4e47ac40dc 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -262,7 +262,7 @@ def _enter_scope(self, name: str, parent_node: ast.AST): graph=ir.Graph((), (), nodes=[]), attributes={}, ) - self._locals.insert(0, {}) + self._locals.append({}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) def _exit_scope(self) -> irbuilder.IRFunction: @@ -270,20 +270,20 @@ def _exit_scope(self) -> irbuilder.IRFunction: logger.debug("Converter:_exit_scope:%d", len(self._locals)) graph = self._current_fn self._current_fn = self._outer.pop() - self._locals.pop(0) + self._locals.pop() return graph def _current_scope(self) -> Dict[str, LocalSymValue]: - return self._locals[0] + return self._locals[-1] def _bind(self, name: str, val: LocalSymValue) -> None: logger.debug("Converter:_bind:%s", name) - self._locals[0][name] = val + self._locals[-1][name] = val def _lookup( self, name: str, info: sourceinfo.SourceInfo, raise_exception: bool = True ) -> SymValue: - for scope in self._locals: + for scope in reversed(self._locals): if name in scope: return scope[name] if name in self._globals: @@ -1342,7 +1342,7 @@ def _translate_block( ) else: pv_val = None - for scope in self._locals: # TODO: skip _current_scope + for scope in reversed(self._locals): # TODO: skip _current_scope if pvar in scope: pv_val = scope[pvar] break From 10473a97534daf39713cbb2f21c8fbefd3dacc0c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Jul 2025 12:45:27 -0700 Subject: [PATCH 05/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 36 ++++++++++++++++++------------------ onnxscript/irbuilder.py | 12 ++++++++++++ 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 4e47ac40dc..f03a9850e8 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -172,6 +172,7 @@ def __init__( # TODO(justinchuby): Update ir version to be user defined self._model = ir.Model(ir.Graph((), (), nodes=()), ir_version=10) + self._tape = ir.tape.Tape(self._model.graph) # A stack of functions in the outer scope self._outer: list[ir.Function] = [] @@ -265,12 +266,13 @@ def _enter_scope(self, name: str, parent_node: ast.AST): self._locals.append({}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) - def _exit_scope(self) -> irbuilder.IRFunction: + def _exit_scope(self) -> ir.Function: """Exit from a control-flow block (a loop body or if-then-else branch).""" logger.debug("Converter:_exit_scope:%d", len(self._locals)) graph = self._current_fn self._current_fn = self._outer.pop() self._locals.pop() + assert graph is not None return graph def _current_scope(self) -> Dict[str, LocalSymValue]: @@ -301,17 +303,17 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: self._used_vars.add(r) return r - def _make_onnx_attr( - self, attrname: str, attrval: Any, attrtype: int | None = None - ) -> irbuilder.IRAttributeValue: - def tensor_name_generator() -> str: - """Return name to be used for tensor, if we need to create one.""" - return self.generate_unique_name(f"attr_{attrname}") + # def _make_onnx_attr( + # self, attrname: str, attrval: Any, attrtype: int | None = None + # ) -> irbuilder.IRAttributeValue: + # def tensor_name_generator() -> str: + # """Return name to be used for tensor, if we need to create one.""" + # return self.generate_unique_name(f"attr_{attrname}") - proto = autocast.pyvalue_to_onnx_attribute( - attrname, attrval, tensor_name_generator, attrtype - ) - return self.ir_builder.make_attr(proto) + # proto = autocast.pyvalue_to_onnx_attribute( + # attrname, attrval, tensor_name_generator, attrtype + # ) + # return self.ir_builder.make_attr(proto) def _to_onnx_attr_ref( self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] @@ -369,18 +371,16 @@ def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Varia def emit( self, - outputs: Sequence[str], + outputs: Sequence[ir.Value], callee: values.Op | str, - inputs: Sequence[Optional[str]], - attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, - sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, + inputs: Sequence[ir.Value | None], + attrs: Any = None, ): if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: - attrs = [] - if sub_functions is None: - sub_functions = {} + attrs = {} + self.ir_builder.add_stmt( self._current_fn, outputs, diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index b4d378bd17..41550a5d86 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -182,6 +182,18 @@ def attr_proto(self) -> onnx.AttributeProto: class IRStmt: + """An IR statement (representing an operation). + + Details: + - `result`: A sequence of variable names that this statement assigns to. + - `callee`: The operation being called, represented as an instance of `values.Op`. + - `args`: A sequence of arguments to the operation, which can be variable names or + `None` for optional arguments. + - `attrs`: A sequence of attributes for the operation, represented as `IRAttributeValue` + instances. + - `sub_functions`: A dictionary of sub-functions that this statement may call, mapping + function names to `onnx.FunctionProto` instances. + """ def __init__( self, result: Sequence[str], From 3889e7050db92f5a11f0d1b862526bfd008dddc7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Jul 2025 17:02:31 -0700 Subject: [PATCH 06/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 149 +++++++++++++++++----------------- onnxscript/ir/_schemas.py | 4 +- onnxscript/type_annotation.py | 51 ++++-------- onnxscript/values.py | 3 +- 4 files changed, 95 insertions(+), 112 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index f03a9850e8..cf966ea50f 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -20,6 +20,7 @@ import onnx import onnx_ir as ir +from onnxscript.ir import _schemas import onnxscript from onnxscript import irbuilder, onnx_types, sourceinfo, values @@ -172,7 +173,6 @@ def __init__( # TODO(justinchuby): Update ir version to be user defined self._model = ir.Model(ir.Graph((), (), nodes=()), ir_version=10) - self._tape = ir.tape.Tape(self._model.graph) # A stack of functions in the outer scope self._outer: list[ir.Function] = [] @@ -316,86 +316,85 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: # return self.ir_builder.make_attr(proto) def _to_onnx_attr_ref( - self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] - ) -> irbuilder.IRAttributeValue: + self, val: values.AttrRef, info: sourceinfo.SourceInfo | None + ) -> ir.Attr: + """Convert an attribute reference to an ONNX ref attribute.""" pytype = val.typeinfo - attrtype = ta.pytype_to_attrtype(pytype) + attrtype = _schemas.get_attr_type(pytype) attrname = None - if attrtype is onnx.AttributeProto.FLOAT: + if attrtype is ir.AttributeType.FLOAT: attrname = "value_float" - elif attrtype is onnx.AttributeProto.INT: + elif attrtype is ir.AttributeType.INT: attrname = "value_int" - elif attrtype is onnx.AttributeProto.STRING: + elif attrtype is ir.AttributeType.STRING: attrname = "value_string" - elif attrtype is onnx.AttributeProto.INTS: + elif attrtype is ir.AttributeType.INTS: attrname = "value_ints" else: msg = f"Unsupported attribute type {pytype!r}." fail(info.msg(msg) if info else msg) - return self.ir_builder.make_attr_ref(attrname, val.value, pytype) + # TODO(justinchuby): What is the ref attr name? + return ir.RefAttr(attrname, val.value, attrtype) def _to_onnx_var( self, val: values.SymbolValue | PyValue, - target: Optional[PreferredName] = None, - info: Optional[sourceinfo.SourceInfo] = None, + target: PreferredName = "tmp", + *, + info: sourceinfo.SourceInfo, ) -> Variable: + """Convert a value to an ONNX variable.""" if isinstance(val, values.AttrRef): # promote attribute to value - result = self.generate_unique_name(target or "tmp") + result = self.generate_unique_name(target) attr = self._to_onnx_attr_ref(val, info) - self.emit([result], values.Op(self.default_opset, "Constant"), [], [attr]) + self.emit("Constant", [], [result], [attr]) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self.generate_unique_name(result + "_as_bool") - cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype) - self.emit( - [result_as_bool], - values.Op(self.default_opset, "Cast"), - [result], - [cast_attr], - ) - return Variable(result_as_bool, True) - return Variable(result, True) + self.emit("Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)]) + return Variable(result_as_bool, castable=True) + return Variable(result, castable=True) + if isinstance(val, values.Dynamic): return Variable(val.value) + # Assume value is a python-value convertible to a tensor - # TODO: check if value is convertible to a TensorProto, so that we can - # produce a better error _message otherwise - return self._emit_const(val, target or "tmp", info) + return self._emit_const(val, target, info) def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable: + """Convert a python variable to an ONNX variable.""" return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) def emit( self, - outputs: Sequence[ir.Value], - callee: values.Op | str, - inputs: Sequence[ir.Value | None], - attrs: Any = None, + op_type: str, + inputs: Sequence[str], + outputs: Sequence[str], + attrs: Sequence[ir.Attr] = (), + domain: str = "", ): - if not isinstance(callee, values.Op): - callee = values.Op(self.default_opset, callee) - if attrs is None: - attrs = {} - - self.ir_builder.add_stmt( - self._current_fn, - outputs, - callee, - inputs, - attrs, - sub_functions, + """Emit an ONNX operator with the given inputs, outputs, and attributes.""" + node = ir.Node( + domain=domain, + op_type=op_type, + inputs=[self._lookup(inp, self._source_of(inputs[0])) for inp in inputs], + attributes=attrs, + outputs=[self._lookup(out, self._source_of(outputs[0])) for out in outputs], ) + assert self._current_fn is not None + self._current_fn.append(node) def _emit_const( self, pyvalue: PyValue, - suggested_name: Optional[PreferredName], + suggested_name: PreferredName | None, info: sourceinfo.SourceInfo, ) -> Variable: + """Emit a constant value as an ONNX Constant node.""" + # Obtain a name for the constant if suggested_name is None: if isinstance(pyvalue, int): if pyvalue >= 0: @@ -411,14 +410,16 @@ def _emit_const( suggested_name = f"int64_m{abs(pyvalue[0])}_1d" else: suggested_name = "const" - ovar = self.generate_unique_name(suggested_name) + var_name = self.generate_unique_name(suggested_name) + + # Create a tensor from the python value try: - tensor = autocast.pyvalue_to_onnx_tensor(ovar, pyvalue) - except ValueError as e: + tensor = ir.tensor(pyvalue, name=var_name) + except Exception as e: fail(info.msg(str(e))) - attr = self._make_onnx_attr("value", tensor) - self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - return Variable(ovar, True) + + self.emit("Constant", [], [var_name], [ir.AttrTensor("value", tensor)]) + return Variable(var_name, True) def _emit_copy(self, original_var: str, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" @@ -426,25 +427,6 @@ def _emit_copy(self, original_var: str, suggested_name: str) -> str: self.emit([new_var], "Identity", [original_var]) return new_var - def _is_constant_expr(self, node: ast.AST) -> None: - if isinstance(node, ast.UnaryOp): - return self._is_constant_expr(node.operand) - if isinstance( - node, - ( - ast.Call, - ast.BinOp, - ast.UnaryOp, - ast.Compare, - ast.Attribute, - ast.List, - ast.Load, - ast.Constant, - ), - ): - return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node)) - return False - def _eval_constant_expr(self, expr: ast.AST) -> PyValue: """Evaluates a sub-expression that is assumed to represent a constant value. The expression can refer only to global names (inherited from the scope @@ -455,7 +437,7 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue: as divergence between eager-mode execution and evaluation of the ONNX function.) """ - # TODO: assert (self._is_constant_expr(expr)) + # TODO: assert (_is_constant_expr(expr)) # TODO: Refine types locals: dict[Any, Any] = {} expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) @@ -558,7 +540,7 @@ def _translate_expr( r = self._translate_name_expr(node) elif isinstance(node, ast.Subscript): r = self._translate_subscript_expr(node, target) - elif self._is_constant_expr(node): + elif _is_constant_expr(node): r = self._emit_const(self._eval_constant_expr(node), target, self._source_of(node)) else: raise ValueError( @@ -657,7 +639,7 @@ def translate_slice_component( ) return const_1d(default_value), default_value - if self._is_constant_expr(node_arg): + if _is_constant_expr(node_arg): cst = self._eval_constant_expr(node_arg) if isinstance(cst, int): return const_1d(cst), cst @@ -709,7 +691,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: # Add to sliced_indices, unless it is "::", which is a no-op. if not (elt.lower is None and elt.upper is None and elt.step is None): sliced_indices.append((axis, elt)) - elif self._is_constant_expr(elt) and isinstance( + elif _is_constant_expr(elt) and isinstance( self._eval_constant_expr(elt), int ): scalar_indices.append((axis, elt)) @@ -851,7 +833,7 @@ def _translate_binary_op_expr(self, node: ast.BinOp): raise ValueError(self._message(node, f"Unsupported operator {op!r}.")) attr = [] - if isinstance(node.op, ast.Mod) and self._is_constant_expr(node.right): + if isinstance(node.op, ast.Mod) and _is_constant_expr(node.right): # specific case X % f where f is a float. # attribute fmod=1 is added in that case. cst = self._eval_constant_expr(node.right) @@ -868,7 +850,7 @@ def _translate_unary_op_expr(self, node): op = type(node.op) if op not in _PRIMOP_MAP: raise ValueError(self._message(node, self).msg(f"Unsupported operator {op!r}.")) - if self._is_constant_expr(node.operand): + if _is_constant_expr(node.operand): # This function changed the constant node.operand # and returns it. The function calling this one # should intercept this call and replace node @@ -1464,3 +1446,24 @@ def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunct domain = self._this_module.domain self._current_fn = self.ir_builder.new_function(fn.name, domain, True) return self._translate_function_signature_common(fn) + + +def _is_constant_expr(node: ast.AST) -> bool: + """Check if the AST node is a constant expression.""" + if isinstance(node, ast.UnaryOp): + return _is_constant_expr(node.operand) + if isinstance( + node, + ( + ast.Call, + ast.BinOp, + ast.UnaryOp, + ast.Compare, + ast.Attribute, + ast.List, + ast.Load, + ast.Constant, + ), + ): + return all(_is_constant_expr(c) for c in ast.iter_child_nodes(node)) + return False diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index d4d88ab5bb..ceebbd807f 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -210,7 +210,7 @@ def _is_optional(type_: type) -> bool: return False -def _get_attr_type(type_: type) -> ir.AttributeType: +def get_attr_type(type_: type) -> ir.AttributeType: """Obtain the type of the attribute from a Python class.""" try: if type_ in _PY_TYPE_TO_ATTR_TYPE: @@ -455,7 +455,7 @@ def from_function( ) else: type_ = type_hints[param.name] - if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED: # Construct the default attribute if param.default is not inspect.Parameter.empty: # TODO: Use ir_convenience instead to handle int as float diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 8a71b5c2d4..cd8f12a8fa 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -8,6 +8,7 @@ from typing import Optional, Sequence, Union import onnx +import onnx_ir as ir from onnxscript import onnx_types from onnxscript._internal import version_utils @@ -25,35 +26,35 @@ # Map from python type to corresponding ONNX AttributeProto type _PYTYPE_TO_ATTRTYPE_MAP = { - float: onnx.AttributeProto.FLOAT, - int: onnx.AttributeProto.INT, - str: onnx.AttributeProto.STRING, - bool: onnx.AttributeProto.INT, # experimental + float: ir.AttributeType.FLOAT, + int: ir.AttributeType.INT, + str: ir.AttributeType.STRING, + bool: ir.AttributeType.INT, # experimental } # Map from python type to corresponding ONNX AttributeProto type, # for repeated (i.e., list of) values _LISTTYPE_TO_ATTRTYPE_MAP = { - float: onnx.AttributeProto.FLOATS, - int: onnx.AttributeProto.INTS, - str: onnx.AttributeProto.STRINGS, - bool: onnx.AttributeProto.INTS, # experimental + float: ir.AttributeType.FLOATS, + int: ir.AttributeType.INTS, + str: ir.AttributeType.STRINGS, + bool: ir.AttributeType.INTS, # experimental } _LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) # Map from ONNX AttributeProto type to its representation (in ONNX Script). _ATTRTYPE_TO_REPR = { - onnx.AttributeProto.FLOAT: "float", - onnx.AttributeProto.INT: "int", - onnx.AttributeProto.STRING: "str", - onnx.AttributeProto.FLOATS: "Sequence[float]", - onnx.AttributeProto.INTS: "Sequence[int]", - onnx.AttributeProto.STRINGS: "Sequence[str]", + ir.AttributeType.FLOAT: "float", + ir.AttributeType.INT: "int", + ir.AttributeType.STRING: "str", + ir.AttributeType.FLOATS: "Sequence[float]", + ir.AttributeType.INTS: "Sequence[int]", + ir.AttributeType.STRINGS: "Sequence[str]", } -def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType) -> str: +def onnx_attr_type_to_onnxscript_repr(attr_type: ir.AttributeType) -> str: if attr_type not in _ATTRTYPE_TO_REPR: supported = ", ".join( f"'{onnx.AttributeProto.AttributeType.Name(v)}'" for v in _ATTRTYPE_TO_REPR @@ -87,26 +88,6 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP -def pytype_to_attrtype( - pytype: TypeAnnotationValue, -) -> Optional[onnx.AttributeProto.AttributeType]: - pytype = _remove_annotation(pytype) - if pytype in _PYTYPE_TO_ATTRTYPE_MAP: - return _PYTYPE_TO_ATTRTYPE_MAP[pytype] - type_constructor = typing.get_origin(pytype) - # Remove Optional wrapper if present, which is represented as an Union[..., type(None)] - if type_constructor is typing.Union: - # Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)] - args = [x for x in typing.get_args(pytype) if x is not type(None)] - if len(args) == 1: - return pytype_to_attrtype(args[0]) - if type_constructor in _LIST_CONSTRUCTORS: - elt_type = typing.get_args(pytype)[0] - if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP: - return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type] - return None - - def base_type_is_bool(pytype: TypeAnnotationValue) -> bool: """Returns True if base type of pytype is bool, False otherwise.""" pytype = _remove_annotation(pytype) diff --git a/onnxscript/values.py b/onnxscript/values.py index d957145281..7df8a6ce1c 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -740,13 +740,12 @@ def __init__( """ super().__init__(info) self.value = attr_name - self.typeinfo = typeinfo + if not isinstance(typeinfo, (type, _GenericAlias)): # typing._GenericAlias for List[int] and List[str], etc. raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") self.typeinfo = typeinfo - class DynamicKind(IntFlag): Unknown = 0 Input = 1 From e09fdccf90f687bdf20f9432acb51433569d8171 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Jul 2025 17:50:57 -0700 Subject: [PATCH 07/31] update Signed-off-by: Justin Chu --- onnxscript/_converter.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index cf966ea50f..275560f5e0 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -397,17 +397,11 @@ def _emit_const( # Obtain a name for the constant if suggested_name is None: if isinstance(pyvalue, int): - if pyvalue >= 0: - suggested_name = f"int64_{pyvalue}" - else: - suggested_name = f"int64_m{abs(pyvalue)}" + suggested_name = f"int64_{pyvalue}" elif ( isinstance(pyvalue, list) and len(pyvalue) == 1 and isinstance(pyvalue[0], int) ): - if pyvalue[0] >= 0: - suggested_name = f"int64_{pyvalue[0]}_1d" - else: - suggested_name = f"int64_m{abs(pyvalue[0])}_1d" + suggested_name = f"int64_{pyvalue[0]}_1d" else: suggested_name = "const" var_name = self.generate_unique_name(suggested_name) @@ -424,7 +418,7 @@ def _emit_const( def _emit_copy(self, original_var: str, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self.generate_unique_name(suggested_name) - self.emit([new_var], "Identity", [original_var]) + self.emit("Identity", [original_var], [new_var]) return new_var def _eval_constant_expr(self, expr: ast.AST) -> PyValue: @@ -438,7 +432,7 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue: function.) """ # TODO: assert (_is_constant_expr(expr)) - # TODO: Refine types + # TODO(justinchuby): Expand locals? locals: dict[Any, Any] = {} expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) cpl = compile(expr, filename="", mode="eval") From fa60de6557dec658e4e8948a6f232e6bad852394 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Jul 2025 12:30:34 -0700 Subject: [PATCH 08/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 150 ++++++++++++++++++++++++++++++-------- onnxscript/ir/_schemas.py | 1 + 2 files changed, 120 insertions(+), 31 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 275560f5e0..e31e38790a 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -11,6 +11,7 @@ Any, Dict, List, + Mapping, NoReturn, Optional, Sequence, @@ -315,27 +316,6 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: # ) # return self.ir_builder.make_attr(proto) - def _to_onnx_attr_ref( - self, val: values.AttrRef, info: sourceinfo.SourceInfo | None - ) -> ir.Attr: - """Convert an attribute reference to an ONNX ref attribute.""" - pytype = val.typeinfo - attrtype = _schemas.get_attr_type(pytype) - attrname = None - if attrtype is ir.AttributeType.FLOAT: - attrname = "value_float" - elif attrtype is ir.AttributeType.INT: - attrname = "value_int" - elif attrtype is ir.AttributeType.STRING: - attrname = "value_string" - elif attrtype is ir.AttributeType.INTS: - attrname = "value_ints" - else: - msg = f"Unsupported attribute type {pytype!r}." - fail(info.msg(msg) if info else msg) - # TODO(justinchuby): What is the ref attr name? - return ir.RefAttr(attrname, val.value, attrtype) - def _to_onnx_var( self, val: values.SymbolValue | PyValue, @@ -347,7 +327,7 @@ def _to_onnx_var( if isinstance(val, values.AttrRef): # promote attribute to value result = self.generate_unique_name(target) - attr = self._to_onnx_attr_ref(val, info) + attr = _to_onnx_ref_attr(val, info) self.emit("Constant", [], [result], [attr]) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types @@ -434,6 +414,7 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue: # TODO: assert (_is_constant_expr(expr)) # TODO(justinchuby): Expand locals? locals: dict[Any, Any] = {} + # TODO(justinchuby): Find a better way to pass lineno and col_offset expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset) cpl = compile(expr, filename="", mode="eval") try: @@ -451,7 +432,7 @@ def _translate_attr( self, attr_name: str, expr: ast.AST, - attr_meta: Optional[onnx.defs.OpSchema.Attribute] = None, + attr_meta: Optional[ir.Attr] = None, ) -> Optional[irbuilder.IRAttributeValue]: """Translate an attribute-value specification of the form `attr_name=` in a call to an op. expr is an AST. The following cases are supported: @@ -465,7 +446,7 @@ def _translate_attr( if isinstance(expr, ast.Name): val = self._lookup(expr.id, self._source_of(expr)) if isinstance(val, values.AttrRef): - attr_ref = self.ir_builder.make_attr_ref(attr_name, val.value, val.typeinfo) + attr_ref = _to_onnx_ref_attr(val, val.typeinfo) if attr_meta is not None and (attr_ref.type != attr_meta.type): self.fail( expr, @@ -793,17 +774,20 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: def _translate_call_expr(self, node: ast.Call): """Translates a call-expression.""" callee = self._translate_callee_expr(node.func) - param_schemas = callee.param_schemas() + op_signature = callee.op_signature # If the callee's schema is available, we use it to determine the inputs and attributes. # Otherwise, we map named arguments to attributes and positional arguments to inputs. - if param_schemas: - kwargs = {x.arg: x.value for x in node.keywords} - args, attrs = param_manipulation.separate_input_attributes_from_arguments( - param_schemas, node.args, kwargs, fill_defaults=False + if op_signature is not None: + args = node.args + kwargs: dict[str, ast.expr] = {x.arg: x.value for x in node.keywords} + # First separate inputs from attributes. This is needed because in Python + # it is possible to pass onnx inputs as kwargs + inputs, attrs = _separate_inputs_and_attrs( + op_signature, args, kwargs ) - args = [self._translate_opt_expr(x) for x in args] + onnx_inputs = [self._translate_opt_expr(x) for x in inputs] attrs = [ - self._translate_attr(x, y, callee.op_schema.attributes[x]) + self._translate_attr(x, y, op_signature.params_map[x]) for x, y in attrs.items() ] else: @@ -1461,3 +1445,107 @@ def _is_constant_expr(node: ast.AST) -> bool: ): return all(_is_constant_expr(c) for c in ast.iter_child_nodes(node)) return False + + + +def _separate_inputs_and_attrs( + signature: _schemas.OpSignature, + args: Sequence[ast.expr], + kwargs: Mapping[str, ast.expr], +) -> tuple[Sequence[ast.expr], dict[str, ast.expr]]: + """Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs. + + This function uses the OpSignature to determine which argument in args and kwargs corresponds to + which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are + stored in named_attrs. If an _optional input_ is not provided, it is filled with None. + + Args: + signature: The OpSignature for the node. + args: The positional arguments for the node. + kwargs: The keyword arguments for the node. + + Returns: + A tuple of two mappings: named_inputs and named_attrs. + + Raises: + ValueError: If a required parameter is not provided. + """ + # 1. Construct inputs, attrs based on (args, kwargs) and the signature. + # a. Loop over all parameters in the signature and args together + # b. Depending on param.is_input, Record inputs or named_attrs[param.name] = arg + # c. Handle kwargs as well + inputs_reversed: Sequence[Any] = [] + named_attrs: dict[str, Any] = {} + reversed_args_stack = list(reversed(args)) + for param in signature.params: + if isinstance(param, _schemas.Parameter): + # Handle inputs + if reversed_args_stack: + # First exhaust the positional arguments + if param.variadic: + # Handle variadic arguments + inputs_reversed = [*reversed(args)] + reversed_args_stack.clear() + else: + inputs_reversed.append(reversed_args_stack.pop()) + elif param.name in kwargs: + inputs_reversed.append(kwargs[param.name]) + elif param.required: + raise ValueError( + f"Required parameter '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional parameter '%s' is not provided. Added as None. Signature: %s", + param.name, + signature, + ) + inputs_reversed.append(None) + else: + # Handle attributes + attribute: ir.Attr | None + assert isinstance(param, _schemas.AttributeParameter), ( + f"Expected AttributeParameter, got {type(param)}" + ) + if reversed_args_stack: + # First exhaust the positional arguments + attribute = reversed_args_stack.pop() # type: ignore[assignment] + elif kwargs.get(param.name) is not None: + attribute = kwargs[param.name] # type: ignore[assignment] + else: + if param.required: + raise ValueError( + f"Required attribute '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional attribute '%s' is None. Dropped. Signature: %s", + param.name, + signature, + ) + continue + named_attrs[param.name] = attribute + return tuple(reversed(inputs_reversed)), named_attrs + +def _to_onnx_ref_attr( + val: values.AttrRef, info: sourceinfo.SourceInfo | None +) -> ir.Attr: + """Convert an attribute reference to an ONNX ref attribute.""" + pytype = val.typeinfo + attrtype = _schemas.get_attr_type(pytype) + attrname = None + if attrtype is ir.AttributeType.FLOAT: + attrname = "value_float" + elif attrtype is ir.AttributeType.INT: + attrname = "value_int" + elif attrtype is ir.AttributeType.STRING: + attrname = "value_string" + elif attrtype is ir.AttributeType.INTS: + attrname = "value_ints" + else: + msg = f"Unsupported attribute type {pytype!r}." + fail(info.msg(msg) if info else msg) + # TODO(justinchuby): What is the ref attr name? + return ir.RefAttr(attrname, val.value, attrtype) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index ceebbd807f..2a2527e31b 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -3,6 +3,7 @@ from __future__ import annotations import collections.abc +import copy import dataclasses import inspect import logging From 3aff171e37166cc1e226063d5ad263f7f1161484 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Jul 2025 15:27:16 -0700 Subject: [PATCH 09/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index e31e38790a..cad590cb30 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -432,17 +432,18 @@ def _translate_attr( self, attr_name: str, expr: ast.AST, - attr_meta: Optional[ir.Attr] = None, - ) -> Optional[irbuilder.IRAttributeValue]: - """Translate an attribute-value specification of the form `attr_name=` - in a call to an op. expr is an AST. The following cases are supported: + # TODO(justinchuby): Is attr_meta needed? + attr_meta: ir.Attr | None = None, + ) -> ir.Attr | None: + """Translate an attribute-value specification of the form `attr_name=` in a call to an op. expr is an AST. + + The following cases are supported: * Expr evaluates to a script-time constant (a python-value) that can be mapped into an ONNX attribute value, or * Expr evaluates to None, in which case None is returned, or * Expr must be an attribute-reference, that is a name representing an attribute-parameter of a containing function. """ - if isinstance(expr, ast.Name): val = self._lookup(expr.id, self._source_of(expr)) if isinstance(val, values.AttrRef): @@ -453,19 +454,24 @@ def _translate_attr( f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'", ) return attr_ref - if isinstance(val, irbuilder.IRFunction): + if isinstance(val, ir.Function): + # if isinstance(val, irbuilder.IRFunction): # Check that outer-scope variables referenced by function have same value # at function-definition site and use-as-attribute site, to avoid errors. - for pyvar, previous in val.outer_scope_variables: - current = self._lookup(pyvar, self._source_of(expr)) - if current.value != previous.value: - self.fail( - expr, - f"Outer scope variable '{pyvar}' referenced by function " - f"'{expr.id!r}' modified.", - ) + + # TODO(justinchuby): Capture outer_scope_variables + # And implement the following + # for pyvar, previous in val.outer_scope_variables: + # current = self._lookup(pyvar, self._source_of(expr)) + # if current.value != previous.value: + # self.fail( + # expr, + # f"Outer scope variable '{pyvar}' referenced by function " + # f"'{expr.id!r}' modified.", + # ) # Create GraphProto attribute + # TODO: Fix this val = val.to_graph_proto() else: val = self._eval_constant_expr(expr) From 761451a892ea2355004ff7ce7301e25ebd690405 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 11:56:41 -0700 Subject: [PATCH 10/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 83 +++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index cad590cb30..3c97c86019 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -177,7 +177,12 @@ def __init__( # A stack of functions in the outer scope self._outer: list[ir.Function] = [] - self._current_fn: ir.Function | None = None + self._current_fn: ir.Function = ir.Function( + domain=self._this_module.domain, + name="", + graph=ir.Graph((), (), nodes=[]), + attributes={}, + ) self._nextvar: int = 0 self._used_vars: set[str] = set() self._locals: list[dict[str, LocalSymValue]] = [{}] @@ -225,13 +230,18 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: def _init_function_translation(self) -> None: """Initialize self for translating a new (top-level) function.""" self._outer = [] - self._current_fn = None + # TODO(justinchuby): Update this + self._current_fn = ir.Function( + domain=self._this_module.domain, + name="", + graph=ir.Graph((), (), nodes=[]), + attributes={}, + ) self._nextvar = 0 self._used_vars = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: - assert self._current_fn is not None return sourceinfo.SourceInfo(node, self._source, self._current_fn.name) def _message(self, node: ast.AST, error_msg: str) -> str: @@ -255,7 +265,6 @@ def _enter_scope(self, name: str, parent_node: ast.AST): """Enter a control-flow block (a loop body or if-then-else branch). The block is translated into a nested-scope in ONNX. """ - assert self._current_fn is not None self._outer.append(self._current_fn) assert self._this_module is not None self._current_fn = ir.Function( @@ -334,7 +343,9 @@ def _to_onnx_var( # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self.generate_unique_name(result + "_as_bool") - self.emit("Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)]) + self.emit( + "Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)] + ) return Variable(result_as_bool, castable=True) return Variable(result, castable=True) @@ -364,7 +375,6 @@ def emit( attributes=attrs, outputs=[self._lookup(out, self._source_of(outputs[0])) for out in outputs], ) - assert self._current_fn is not None self._current_fn.append(node) def _emit_const( @@ -454,12 +464,12 @@ def _translate_attr( f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'", ) return attr_ref - if isinstance(val, ir.Function): - # if isinstance(val, irbuilder.IRFunction): + if isinstance(val, ir.Graph): + # if isinstance(val, irbuilder.IRFunction): # Check that outer-scope variables referenced by function have same value # at function-definition site and use-as-attribute site, to avoid errors. - # TODO(justinchuby): Capture outer_scope_variables + # TODO(justinchuby): Capture outer_scope_variables? # And implement the following # for pyvar, previous in val.outer_scope_variables: # current = self._lookup(pyvar, self._source_of(expr)) @@ -470,9 +480,8 @@ def _translate_attr( # f"'{expr.id!r}' modified.", # ) - # Create GraphProto attribute - # TODO: Fix this - val = val.to_graph_proto() + # Create Graph attribute + pass else: val = self._eval_constant_expr(expr) @@ -482,25 +491,15 @@ def _translate_attr( # The caller is responsible for omitting such attribute-values from the list of attributes # in a NodeProto. if val is None: - if attr_meta and attr_meta.required: - self.fail(expr, f"Attribute '{attr_name}' is required.") return None - attr_type = int(attr_meta.type) if attr_meta else None - attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type) - if attr_meta and (attr.type != attr_meta.type): - self.fail( - expr, - f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'", - ) + attr = ir.convenience.convert_attribute( + attr_name, val, attr_type=attr_meta.type if attr_meta else None + ) return attr - def _translate_docstring(self, node: ast.Expr) -> None: - if hasattr(node.value, "value"): - # python 3.8+ - return self.ir_builder.add_docstring(self._current_fn, node.value.value) - raise TypeError( - f"Unexpected type {type(node)!r} for node. Unsupoorted version of python." - ) + def _translate_docstring(self, node: ast.FunctionDef) -> None: + if docstring := ast.get_docstring(node): + self._current_fn.doc_string = docstring def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None @@ -672,9 +671,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: # Add to sliced_indices, unless it is "::", which is a no-op. if not (elt.lower is None and elt.upper is None and elt.step is None): sliced_indices.append((axis, elt)) - elif _is_constant_expr(elt) and isinstance( - self._eval_constant_expr(elt), int - ): + elif _is_constant_expr(elt) and isinstance(self._eval_constant_expr(elt), int): scalar_indices.append((axis, elt)) else: non_scalar_indices.append((axis, elt)) @@ -788,9 +785,7 @@ def _translate_call_expr(self, node: ast.Call): kwargs: dict[str, ast.expr] = {x.arg: x.value for x in node.keywords} # First separate inputs from attributes. This is needed because in Python # it is possible to pass onnx inputs as kwargs - inputs, attrs = _separate_inputs_and_attrs( - op_signature, args, kwargs - ) + inputs, attrs = _separate_inputs_and_attrs(op_signature, args, kwargs) onnx_inputs = [self._translate_opt_expr(x) for x in inputs] attrs = [ self._translate_attr(x, y, op_signature.params_map[x]) @@ -944,8 +939,6 @@ def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None: if isinstance(node, (ast.For, ast.While)): return self._translate_loop_stmt(node) if ast_utils.is_doc_string(node): - if index_of_stmt == 0: - return self._translate_docstring(node) return None if isinstance(node, ast.FunctionDef): return self._translate_nested_function_def(node) @@ -1401,12 +1394,16 @@ def _translate_function_signature_common( return self._current_fn - def _translate_function_def_common(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: + def _translate_function_def_common(self, node: ast.FunctionDef) -> ir.Function: """Translate a function definition, including the signature and its body.""" - logger.debug("Converter:_translate_function_def_common:%s", fn.name) - _ = self._translate_function_signature_common(fn) - for i, s in enumerate(fn.body): + logger.debug("Converter:_translate_function_def_common:%s", node.name) + _ = self._translate_function_signature_common(node) + for i, s in enumerate(node.body): self._translate_stmt(s, index_of_stmt=i) + + # Update docstring if available + if docstring := ast.get_docstring(node): + self._current_fn.doc_string = docstring return self._current_fn def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: @@ -1453,7 +1450,6 @@ def _is_constant_expr(node: ast.AST) -> bool: return False - def _separate_inputs_and_attrs( signature: _schemas.OpSignature, args: Sequence[ast.expr], @@ -1535,9 +1531,8 @@ def _separate_inputs_and_attrs( named_attrs[param.name] = attribute return tuple(reversed(inputs_reversed)), named_attrs -def _to_onnx_ref_attr( - val: values.AttrRef, info: sourceinfo.SourceInfo | None -) -> ir.Attr: + +def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -> ir.Attr: """Convert an attribute reference to an ONNX ref attribute.""" pytype = val.typeinfo attrtype = _schemas.get_attr_type(pytype) From 882af66d6323d9b10e227caff0f21dfd8f2d0b62 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 12:07:16 -0700 Subject: [PATCH 11/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 60 +++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 32 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 3c97c86019..3c8759760b 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -140,32 +140,41 @@ def __str__(self) -> str: class Converter: """Main class to translate python code into ONNX operators. - The class uses logger `onnxscript`. Logging can be enabled with the following code: + The converter translates a Python function into an ONNX function by + traversing the Python AST of the function and generating ONNX nodes + that represent the operations in the Python code. - :: + ..tip:: - import logging - logging.basicConfig(level=logging.DEBUG) + The class uses logger `onnxscript`. Logging can be enabled with the following code: - Or if you need to enable only the logger used by this module: + :: + + import logging + logging.basicConfig(level=logging.DEBUG) + + Or if you need to enable only the logger used by this module: - :: + :: - import logging - logger = logging.getLogger('onnxscript') - logger.setLevel(logging.DEBUG) - console = logging.StreamHandler() - logger.addHandler(console) + import logging + logger = logging.getLogger('onnxscript') + logger.setLevel(logging.DEBUG) + console = logging.StreamHandler() + logger.addHandler(console) """ def __init__( self, + root: ast.FunctionDef, opset: Optional[values.Opset] = None, global_names: Optional[dict[str, Any]] = None, source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): self._source = source + self._root = root + if global_names is not None: # We make a copy in case function eval modifies it. self._globals = global_names.copy() @@ -313,18 +322,6 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: self._used_vars.add(r) return r - # def _make_onnx_attr( - # self, attrname: str, attrval: Any, attrtype: int | None = None - # ) -> irbuilder.IRAttributeValue: - # def tensor_name_generator() -> str: - # """Return name to be used for tensor, if we need to create one.""" - # return self.generate_unique_name(f"attr_{attrname}") - - # proto = autocast.pyvalue_to_onnx_attribute( - # attrname, attrval, tensor_name_generator, attrtype - # ) - # return self.ir_builder.make_attr(proto) - def _to_onnx_var( self, val: values.SymbolValue | PyValue, @@ -497,10 +494,6 @@ def _translate_attr( ) return attr - def _translate_docstring(self, node: ast.FunctionDef) -> None: - if docstring := ast.get_docstring(node): - self._current_fn.doc_string = docstring - def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None ) -> Variable: @@ -1323,7 +1316,7 @@ def _translate_block( def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: """Translate a nested function definition.""" self._enter_scope(fn.name, fn) - self._translate_function_def_common(fn) + self._translate_function_def(fn) function_ir = self._exit_scope() outer_scope_vars = analysis.outer_scope_variables(fn, self._message) function_ir.outer_scope_variables = [ @@ -1394,16 +1387,16 @@ def _translate_function_signature_common( return self._current_fn - def _translate_function_def_common(self, node: ast.FunctionDef) -> ir.Function: + def _translate_function_def(self, node: ast.FunctionDef) -> ir.Function: """Translate a function definition, including the signature and its body.""" - logger.debug("Converter:_translate_function_def_common:%s", node.name) + logger.debug("Converter:_translate_function_def:%s", node.name) _ = self._translate_function_signature_common(node) for i, s in enumerate(node.body): self._translate_stmt(s, index_of_stmt=i) # Update docstring if available if docstring := ast.get_docstring(node): - self._current_fn.doc_string = docstring + self._current_fn.doc_string = docstring return self._current_fn def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: @@ -1416,7 +1409,7 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: domain = self._this_module.domain self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) analysis.do_liveness_analysis(stmt, self._message) - fn_ir = self._translate_function_def_common(stmt) + fn_ir = self._translate_function_def(stmt) fn_ir.debug_print() self._this_module.add_function_def(fn_ir) return fn_ir @@ -1424,6 +1417,7 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: """Translate a (top-level) function signature.""" + assert self._this_module is not None domain = self._this_module.domain self._current_fn = self.ir_builder.new_function(fn.name, domain, True) return self._translate_function_signature_common(fn) @@ -1534,6 +1528,8 @@ def _separate_inputs_and_attrs( def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -> ir.Attr: """Convert an attribute reference to an ONNX ref attribute.""" + + # TODO(justinchuby): Consider using a convenience function pytype = val.typeinfo attrtype = _schemas.get_attr_type(pytype) attrname = None From 852cc422aedaaefd1b4fed1b7c7f36ff35ea5eb7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 16:29:20 -0700 Subject: [PATCH 12/31] update Signed-off-by: Justin Chu --- onnxscript/_converter.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 3c8759760b..c33160d601 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -167,27 +167,42 @@ class Converter: def __init__( self, root: ast.FunctionDef, + *, opset: Optional[values.Opset] = None, global_names: Optional[dict[str, Any]] = None, source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): - self._source = source + """Initialize the converter. + + Args: + root: The root AST node of the function to be converted. + opset: The ONNX opset to use for the conversion. If None, the default opset is used. + global_names: A dictionary of global names available in the script. + source: Optional source code string for error reporting. + default_opset: The default ONNX opset to use if no ONNX opset is specified in the script. + """ + self._root = root + self._opset = opset if global_names is not None: # We make a copy in case function eval modifies it. self._globals = global_names.copy() - self._this_module = opset + else: + self._globals = {} + + self._source = source self._default_opset = default_opset # TODO(justinchuby): Update ir version to be user defined + # TODO(justinchuby): Maybe just store a list of functions self._model = ir.Model(ir.Graph((), (), nodes=()), ir_version=10) # A stack of functions in the outer scope self._outer: list[ir.Function] = [] self._current_fn: ir.Function = ir.Function( - domain=self._this_module.domain, + domain=self._opset.domain, name="", graph=ir.Graph((), (), nodes=[]), attributes={}, @@ -241,7 +256,7 @@ def _init_function_translation(self) -> None: self._outer = [] # TODO(justinchuby): Update this self._current_fn = ir.Function( - domain=self._this_module.domain, + domain=self._opset.domain, name="", graph=ir.Graph((), (), nodes=[]), attributes={}, @@ -275,9 +290,9 @@ def _enter_scope(self, name: str, parent_node: ast.AST): The block is translated into a nested-scope in ONNX. """ self._outer.append(self._current_fn) - assert self._this_module is not None + assert self._opset is not None self._current_fn = ir.Function( - domain=self._this_module.domain, + domain=self._opset.domain, name=name, graph=ir.Graph((), (), nodes=[]), attributes={}, @@ -1406,19 +1421,19 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: opset = self._find_onnx_opset(stmt) if opset: self._set_default_opset(opset, stmt) - domain = self._this_module.domain + domain = self._opset.domain self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) analysis.do_liveness_analysis(stmt, self._message) fn_ir = self._translate_function_def(stmt) fn_ir.debug_print() - self._this_module.add_function_def(fn_ir) + self._opset.add_function_def(fn_ir) return fn_ir raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.") def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: """Translate a (top-level) function signature.""" - assert self._this_module is not None - domain = self._this_module.domain + assert self._opset is not None + domain = self._opset.domain self._current_fn = self.ir_builder.new_function(fn.name, domain, True) return self._translate_function_signature_common(fn) From be610e9988e51d1a583e0d3fe7962dcd4a52bad2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 16:40:16 -0700 Subject: [PATCH 13/31] continue Signed-off-by: Justin Chu --- onnxscript/_converter.py | 76 ++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index c33160d601..669a232e6e 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -229,7 +229,7 @@ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None: or opset.version != self._default_opset.version ): self.fail( - node, f"Two distincts opset were used ({opset} != {self._default_opset})." + node, f"Two distinct opset were used ({opset} != {self._default_opset})." ) else: self._default_opset = opset @@ -251,19 +251,19 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: return res return None - def _init_function_translation(self) -> None: - """Initialize self for translating a new (top-level) function.""" - self._outer = [] - # TODO(justinchuby): Update this - self._current_fn = ir.Function( - domain=self._opset.domain, - name="", - graph=ir.Graph((), (), nodes=[]), - attributes={}, - ) - self._nextvar = 0 - self._used_vars = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + # def _init_function_translation(self) -> None: + # """Initialize self for translating a new (top-level) function.""" + # self._outer = [] + # # TODO(justinchuby): Update this + # self._current_fn = ir.Function( + # domain=self._opset.domain, + # name="", + # graph=ir.Graph((), (), nodes=[]), + # attributes={}, + # ) + # self._nextvar = 0 + # self._used_vars = set() + # self._locals: List[Dict[str, LocalSymValue]] = [{}] def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: return sourceinfo.SourceInfo(node, self._source, self._current_fn.name) @@ -328,7 +328,7 @@ def _lookup( raise ValueError(info.msg(f"Unbound name: {name}.")) return None - def generate_unique_name(self, candidate: str = "tmp") -> str: + def _generate_unique_name(self, candidate: str = "tmp") -> str: # TODO(justinchuby): Can we reduce the O complexity of this function? r = candidate while r in self._used_vars: @@ -347,14 +347,14 @@ def _to_onnx_var( """Convert a value to an ONNX variable.""" if isinstance(val, values.AttrRef): # promote attribute to value - result = self.generate_unique_name(target) + result = self._generate_unique_name(target) attr = _to_onnx_ref_attr(val, info) self.emit("Constant", [], [result], [attr]) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. - result_as_bool = self.generate_unique_name(result + "_as_bool") + result_as_bool = self._generate_unique_name(result + "_as_bool") self.emit( "Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)] ) @@ -406,7 +406,7 @@ def _emit_const( suggested_name = f"int64_{pyvalue[0]}_1d" else: suggested_name = "const" - var_name = self.generate_unique_name(suggested_name) + var_name = self._generate_unique_name(suggested_name) # Create a tensor from the python value try: @@ -419,7 +419,7 @@ def _emit_const( def _emit_copy(self, original_var: str, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" - new_var = self.generate_unique_name(suggested_name) + new_var = self._generate_unique_name(suggested_name) self.emit("Identity", [original_var], [new_var]) return new_var @@ -539,7 +539,7 @@ def _translate_expr( callee, args, attrs = r target = "tmp" if target is None else target assert isinstance(target, str) - result = self.generate_unique_name(target) + result = self._generate_unique_name(target) self.emit([result], callee, args, attrs) return Variable(result) @@ -594,7 +594,7 @@ def _translate_subscript_expr( var_name = var.name if target is None: target = f"{var_name}_subscripted" - target = self.generate_unique_name(target) + target = self._generate_unique_name(target) indices = ast_utils.normalize_subscript_expr(node) info = self._source_of(node.slice) @@ -635,7 +635,7 @@ def translate_slice_component( raise RuntimeError(f"Slice component type must be int, not {type(cst)}") else: name = self._translate_expr(node_arg).name - reshaped = self.generate_unique_name(f"{name}_reshaped") + reshaped = self._generate_unique_name(f"{name}_reshaped") self.emit( [reshaped], values.Op(self.default_opset, "Reshape"), @@ -721,16 +721,16 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: if len(starts) > 1: axis_0_attr = self._make_onnx_attr("axis", 0) - start_name = self.generate_unique_name(f"{var_name}_start") + start_name = self._generate_unique_name(f"{var_name}_start") self.emit([start_name], "Concat", starts, [axis_0_attr]) - end_name = self.generate_unique_name(f"{var_name}_end") + end_name = self._generate_unique_name(f"{var_name}_end") self.emit([end_name], "Concat", ends, [axis_0_attr]) - axes_name = self.generate_unique_name(f"{var_name}_axis") + axes_name = self._generate_unique_name(f"{var_name}_axis") self.emit([axes_name], "Concat", axes, [axis_0_attr]) - steps_name = self.generate_unique_name(f"{var_name}_step") + steps_name = self._generate_unique_name(f"{var_name}_step") self.emit([steps_name], "Concat", steps, [axis_0_attr]) else: start_name = starts[0] @@ -739,7 +739,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: steps_name = steps[0] if squeezed_axes: - sliced_name = self.generate_unique_name(f"{var_name}_sliced") + sliced_name = self._generate_unique_name(f"{var_name}_sliced") self.emit( [sliced_name], "Slice", @@ -748,14 +748,14 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) if non_scalar_indices: # use temporary to store result of squeeze - result = self.generate_unique_name(f"{var_name}_squeezed") + result = self._generate_unique_name(f"{var_name}_squeezed") else: # store squeezed result in final target result = target self.emit([result], "Squeeze", [sliced_name, squeezed_axes]) else: if non_scalar_indices: # use temporary to store result of Slice - result = self.generate_unique_name(f"{var_name}_sliced") + result = self._generate_unique_name(f"{var_name}_sliced") else: # store result of Slice in final target result = target slice_inputs = [var_name, start_name, end_name, axes_name, steps_name] @@ -774,7 +774,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: # use Gather to perform indexing # Assign gathered value to either temporary or final target if axis != last_axis: # use temporary to store result of Gather - gathered = self.generate_unique_name(f"{var_name}_axis_{axis}") + gathered = self._generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target self.emit([gathered], "Gather", [str(result), index_value], [axis_attr]) @@ -876,7 +876,7 @@ def _translate_compare_expr(self, node): op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal") left, right = self._cast_like_binary_expression(op, left, right) if opname == "NotEqual": - tmp = self.generate_unique_name() + tmp = self._generate_unique_name() self.emit([tmp], op, [left, right]) not_op = values.Op(self.default_opset, "Not") return not_op, [tmp], [] @@ -979,7 +979,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: def generate_onnx_name(x: ast.AST): if not isinstance(x, ast.Name): self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'") - onnx_name = self.generate_unique_name(x.id) + onnx_name = self._generate_unique_name(x.id) self._bind( x.id, values.Dynamic( @@ -1078,7 +1078,7 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: elseAttr = self._make_onnx_attr("else_branch", elseGraph) def rename(x): - r = self.generate_unique_name(x) + r = self._generate_unique_name(x) self._bind( x, values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)), @@ -1122,7 +1122,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name - o_cond_var = self.generate_unique_name("cond_in") + o_cond_var = self._generate_unique_name("cond_in") i_cond_var = o_cond_var cond_while = None o_loop_condition = "" # No condition for a for loop. @@ -1156,7 +1156,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # build loop_body self._enter_scope("loop_body", loop_stmt) - o_loop_var = self.generate_unique_name(p_loop_var) + o_loop_var = self._generate_unique_name(p_loop_var) self.ir_builder.add_input( self._current_fn, o_loop_var, @@ -1176,7 +1176,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) for pv in loop_state_vars: - ov = self.generate_unique_name(pv) + ov = self._generate_unique_name(pv) # TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self._eval_constant_expr(pv.annotation) typeinfo = None @@ -1217,7 +1217,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: continue self._translate_stmt(s) - o_cond_out = self.generate_unique_name("cond_out") + o_cond_out = self._generate_unique_name("cond_out") if cond_while is not None: # Loop while @@ -1267,7 +1267,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: info = self._source_of(loop_stmt) def rename(x): - r = self.generate_unique_name(x) + r = self._generate_unique_name(x) self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info)) return r From 98754ab46e2c263b1d9399ab82ae7fa4bab1119a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 17:14:34 -0700 Subject: [PATCH 14/31] Update analysis pass Signed-off-by: Justin Chu --- onnxscript/_converter.py | 155 +++++++++--------- .../_internal/{analysis.py => _analysis.py} | 43 +++-- .../{analysis_test.py => _analysis_test.py} | 8 +- 3 files changed, 105 insertions(+), 101 deletions(-) rename onnxscript/_internal/{analysis.py => _analysis.py} (87%) rename onnxscript/_internal/{analysis_test.py => _analysis_test.py} (95%) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 669a232e6e..20c6ac952e 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -5,6 +5,8 @@ from __future__ import annotations import ast +from collections import defaultdict +import dataclasses import logging from typing import ( TYPE_CHECKING, @@ -26,7 +28,7 @@ import onnxscript from onnxscript import irbuilder, onnx_types, sourceinfo, values from onnxscript import type_annotation as ta -from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation +from onnxscript._internal import _analysis, ast_utils, autocast, param_manipulation if TYPE_CHECKING: # The type-alias LocalSymValue represents the types of values that local names in a @@ -137,6 +139,18 @@ def __str__(self) -> str: return self.name +@dataclasses.dataclass +class ASTMeta: + """Metadata for an AST node. + + This class is used to store metadata about an AST node. + """ + + # For liveness analysis, + live_out: set[ast.AST] = dataclasses.field(default_factory=set) + live_in: set[ast.AST] = dataclasses.field(default_factory=set) + + class Converter: """Main class to translate python code into ONNX operators. @@ -182,8 +196,11 @@ def __init__( source: Optional source code string for error reporting. default_opset: The default ONNX opset to use if no ONNX opset is specified in the script. """ - - self._root = root + if not isinstance(root, ast.FunctionDef): + raise TypeError( + f"Converter expects an AST FunctionDef node, got {type(root)}." + ) + self._ast_root = root self._opset = opset if global_names is not None: @@ -193,7 +210,12 @@ def __init__( self._globals = {} self._source = source - self._default_opset = default_opset + self._default_opset = default_opset or _find_onnx_opset(root, self._globals) + if self._default_opset is None: + raise ValueError( + "default_opset must be specified in script for functions " + "that do not contain any use of an ONNX opset." + ) # TODO(justinchuby): Update ir version to be user defined # TODO(justinchuby): Maybe just store a list of functions @@ -210,46 +232,8 @@ def __init__( self._nextvar: int = 0 self._used_vars: set[str] = set() self._locals: list[dict[str, LocalSymValue]] = [{}] - - @property - def default_opset(self) -> values.Opset: - if self._default_opset is None: - raise RuntimeError( - "default_opset must be specified in script for functions " - "that do not contain any use of an ONNX opset." - ) - return self._default_opset - - def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None: - if opset.domain != "": - return - if self._default_opset is not None: - if ( - opset.domain != self._default_opset.domain - or opset.version != self._default_opset.version - ): - self.fail( - node, f"Two distinct opset were used ({opset} != {self._default_opset})." - ) - else: - self._default_opset = opset - - def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]: - """Find the (first) ONNX opset used in the function, if any.""" - # Search for a Call expression of form "op.OpName(...)" - if isinstance(node, ast.Call): - if isinstance(node.func, ast.Attribute): - opset_expr = node.func.value - if isinstance(opset_expr, ast.Name): - if opset_expr.id in self._globals: - opset = self._globals[opset_expr.id] - if isinstance(opset, values.Opset) and opset.domain == "": - return opset - for child in ast.iter_child_nodes(node): - res = self._find_onnx_opset(child) - if res is not None: - return res - return None + self._finalized = False + self.meta: defaultdict[ast.AST, ASTMeta] = defaultdict(ASTMeta) # def _init_function_translation(self) -> None: # """Initialize self for translating a new (top-level) function.""" @@ -638,7 +622,7 @@ def translate_slice_component( reshaped = self._generate_unique_name(f"{name}_reshaped") self.emit( [reshaped], - values.Op(self.default_opset, "Reshape"), + values.Op(self._default_opset, "Reshape"), [name, one_1d().name], [], ) @@ -827,7 +811,7 @@ def _translate_binary_op_expr(self, node: ast.BinOp): if isinstance(cst, float): attr = [self._make_onnx_attr("fmod", 1)] - op = values.Op(self.default_opset, _PRIMOP_MAP[op]) + op = values.Op(self._default_opset, _PRIMOP_MAP[op]) left, right = self._cast_like_binary_expression( op, self._translate_expr(node.left), self._translate_expr(node.right) ) @@ -858,7 +842,7 @@ def _translate_unary_op_expr(self, node): return self._translate_expr(node.operand) opname = _PRIMOP_MAP[op] operand = self._translate_expr(node.operand) - return values.Op(self.default_opset, opname), [operand], [] + return values.Op(self._default_opset, opname), [operand], [] def _translate_compare_expr(self, node): # TODO: handle multiple comparisons in one expression @@ -873,12 +857,12 @@ def _translate_compare_expr(self, node): # NotEqual is not a standard ONNX op, and needs to be translated into # an Equal op/node followed by a Not op/node. - op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal") + op = values.Op(self._default_opset, opname if opname != "NotEqual" else "Equal") left, right = self._cast_like_binary_expression(op, left, right) if opname == "NotEqual": tmp = self._generate_unique_name() self.emit([tmp], op, [left, right]) - not_op = values.Op(self.default_opset, "Not") + not_op = values.Op(self._default_opset, "Not") return not_op, [tmp], [] return op, [left, right], [] @@ -918,12 +902,12 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable if isinstance(found, values.Op): return found if not found: - if function_name not in self.default_opset: + if function_name not in self._default_opset: warn( f"Unknown function name {function_name!r}. " f"The ONNX graph may not work." ) - return values.Op(self.default_opset, function_name) + return values.Op(self._default_opset, function_name) self.fail(node, "Invalid callee") def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None: @@ -1062,10 +1046,10 @@ def ret(exp, i, suffix): def _translate_if_stmt(self, stmt: ast.If) -> None: if hasattr(stmt, "live_out"): live_defs = list( - stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message)) + stmt.live_out.intersection(_analysis.assigned_vars(stmt, self._message)) ) else: - live_defs = list(analysis.assigned_vars(stmt, self._message)) + live_defs = list(_analysis.assigned_vars(stmt, self._message)) test = self._translate_expr(stmt.test, "cond").name lineno = self._source_of(stmt).lineno thenGraph, sub_fct_then = self._translate_block( @@ -1097,7 +1081,7 @@ def rename(x): self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") self.emit( renamed, - values.Op(self.default_opset, "If"), + values.Op(self._default_opset, "If"), [test], [thenAttr, elseAttr], sub_functions=sub_functions, @@ -1145,8 +1129,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: else: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body - exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message) - vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message) + exposed_uses = _analysis.exposed_uses(loop_stmt.body, self._message) + vars_def_in_loop = _analysis.assigned_vars(loop_stmt.body, self._message) loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) @@ -1232,7 +1216,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.emit( [o_cond_out], - values.Op(self.default_opset, operator_name), + values.Op(self._default_opset, operator_name), [condition_name or o_cond_var], [], ) @@ -1333,7 +1317,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: self._enter_scope(fn.name, fn) self._translate_function_def(fn) function_ir = self._exit_scope() - outer_scope_vars = analysis.outer_scope_variables(fn, self._message) + outer_scope_vars = _analysis.outer_scope_variables(fn, self._message) function_ir.outer_scope_variables = [ (var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars ] @@ -1343,7 +1327,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: def _translate_function_signature_common( self, fn: ast.FunctionDef - ) -> irbuilder.IRFunction: + ) -> ir.Function: """Translate a function signature (top-level or nested).""" args = fn.args if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg: @@ -1414,28 +1398,19 @@ def _translate_function_def(self, node: ast.FunctionDef) -> ir.Function: self._current_fn.doc_string = docstring return self._current_fn - def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: - if isinstance(stmt, ast.FunctionDef): - self._init_function_translation() - if self._default_opset is None: - opset = self._find_onnx_opset(stmt) - if opset: - self._set_default_opset(opset, stmt) - domain = self._opset.domain - self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) - analysis.do_liveness_analysis(stmt, self._message) - fn_ir = self._translate_function_def(stmt) - fn_ir.debug_print() - self._opset.add_function_def(fn_ir) - return fn_ir - raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.") - - def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: - """Translate a (top-level) function signature.""" - assert self._opset is not None - domain = self._opset.domain - self._current_fn = self.ir_builder.new_function(fn.name, domain, True) - return self._translate_function_signature_common(fn) + def _finalize(self) -> None: + self._finalized = True + + def convert(self) -> ir.Function: + """Convert the Python AST to an ONNX IR function.""" + if self._finalized: + return self._current_fn + + func_def = self._ast_root + _analysis.do_liveness_analysis(func_def, self._message, self.meta) + return self._translate_function_def(func_def) + # TODO(justinchuby): Handle function registration to the opset + # self._opset.add_function_def(fn_ir) def _is_constant_expr(node: ast.AST) -> bool: @@ -1561,3 +1536,21 @@ def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) - fail(info.msg(msg) if info else msg) # TODO(justinchuby): What is the ref attr name? return ir.RefAttr(attrname, val.value, attrtype) + + +def _find_onnx_opset(node: ast.AST, globals: dict[str, Any]) -> values.Opset | None: + """Find the (first) ONNX opset used in the function, if any.""" + # Search for a Call expression of form "op.OpName(...)" + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute): + opset_expr = node.func.value + if isinstance(opset_expr, ast.Name): + if opset_expr.id in globals: + opset = globals[opset_expr.id] + if isinstance(opset, values.Opset) and opset.domain == "": + return opset + for child in ast.iter_child_nodes(node): + res = _find_onnx_opset(child, globals) + if res is not None: + return res + return None diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/_analysis.py similarity index 87% rename from onnxscript/_internal/analysis.py rename to onnxscript/_internal/_analysis.py index 0403f60c91..50e1cba1bd 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/_analysis.py @@ -1,13 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +"""Analysis utilities for Python AST.""" from __future__ import annotations import ast -from typing import Any, Optional, Sequence, Set +from typing import Any, Optional, Sequence, TYPE_CHECKING +from collections import defaultdict from onnxscript import sourceinfo from onnxscript._internal import ast_utils +if TYPE_CHECKING: + from onnxscript import _converter + def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str: if not isinstance(for_stmt.target, ast.Name): @@ -15,7 +20,7 @@ def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str: return for_stmt.target.id -def _used_vars(expr: Optional[ast.expr]) -> Set[str]: +def _used_vars(expr: Optional[ast.expr]) -> set[str]: """Return set of all variables used, including function names, in an expression.""" if expr is None: return set() @@ -35,7 +40,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]: return result -def _lhs_vars(lhs: ast.expr) -> Set[str]: +def _lhs_vars(lhs: ast.expr) -> set[str]: """Return set of assigned variables in the lhs of an assignment statement.""" def get_id(e): @@ -49,12 +54,12 @@ def get_id(e): def assigned_vars( stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter -) -> Set[str]: +) -> set[str]: """Return the set of all variables that may be assigned to in an execution of input stmt or sequence of statements. """ - def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: + def assigned_in_block(block: Sequence[ast.stmt]) -> set[str]: result: set[Any] = set() for s in block: result = result | assigned_vars(s, formatter) @@ -84,20 +89,26 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: raise ValueError(error_message) -def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): - """Perform liveness analysis of the given function-ast. The results of the - analysis are stored directly with each statement-ast `s` as attributes `s.live_in` - and `s.live_out`. +def do_liveness_analysis( + fun: ast.FunctionDef, + formatter: sourceinfo.Formatter, + meta: defaultdict[ast.AST, _converter.ASTMeta], +): + """Perform liveness analysis of the given function-ast. + + The results of the analysis are stored in the `meta` dictionary, which maps + each AST node to its metadata. The metadata includes the set of live variables + at the entry and exit of each node. """ - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - stmt.live_out = live_out # type: ignore[attr-defined] + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: + meta[stmt].live_out = live_out live = do_visit(stmt, live_out) - stmt.live_in = live # type: ignore[attr-defined] + meta[stmt].live_in = live return live - def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def do_visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for s in reversed(block): live_out = visit(s, live_out) return live_out @@ -165,12 +176,12 @@ def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter): (in the first statement). Hence x is included in the exposed_uses. """ - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for stmt in reversed(block): live_out = visit(stmt, live_out) return live_out - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: if isinstance(stmt, ast.Assign): return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): diff --git a/onnxscript/_internal/analysis_test.py b/onnxscript/_internal/_analysis_test.py similarity index 95% rename from onnxscript/_internal/analysis_test.py rename to onnxscript/_internal/_analysis_test.py index 74e7ca4c18..64e9b5b110 100644 --- a/onnxscript/_internal/analysis_test.py +++ b/onnxscript/_internal/_analysis_test.py @@ -6,7 +6,7 @@ import unittest from typing import Any -from onnxscript._internal import analysis, ast_utils +from onnxscript._internal import _analysis, ast_utils from onnxscript.onnx_opset import opset15 as op from onnxscript.sourceinfo import formatter @@ -30,7 +30,7 @@ def generic_visit(self, node): class TestLivenessAnalysis(unittest.TestCase): def analyze(self, fun): source, parse_tree = ast_utils.get_src_and_ast(fun) - analysis.do_liveness_analysis(parse_tree, formatter(source)) + _analysis.do_liveness_analysis(parse_tree, formatter(source)) visitor = AnalysisResultsVisitor() visitor.visit(parse_tree) return visitor.results @@ -113,7 +113,7 @@ def while_eg(x): class TestExposedUses(unittest.TestCase): def assertUses(self, f, expected): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.exposed_uses(parse_tree.body, formatter(source)) + result = _analysis.exposed_uses(parse_tree.body, formatter(source)) self.assertEqual(result, set(expected)) def test_basic(self): @@ -190,7 +190,7 @@ def f(x): class TestAssignedVarAnalysis(unittest.TestCase): def assert_assigned_vars(self, f, expected: set[str]): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.assigned_vars(parse_tree.body, formatter(source)) + result = _analysis.assigned_vars(parse_tree.body, formatter(source)) self.assertEqual(result, expected) def test_basic_defs(self): From 22eeddc514a0a129b76a619bd3203a35f077bcd6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 17:19:47 -0700 Subject: [PATCH 15/31] live_out Signed-off-by: Justin Chu --- onnxscript/_converter.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 20c6ac952e..d1e63fc707 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -147,8 +147,8 @@ class ASTMeta: """ # For liveness analysis, - live_out: set[ast.AST] = dataclasses.field(default_factory=set) - live_in: set[ast.AST] = dataclasses.field(default_factory=set) + live_out: set[str] | None = None + live_in: set[str] | None = None class Converter: @@ -1044,9 +1044,9 @@ def ret(exp, i, suffix): return ret(val, 0, "") def _translate_if_stmt(self, stmt: ast.If) -> None: - if hasattr(stmt, "live_out"): + if (live_out := self.meta[stmt].live_out) is not None: live_defs = list( - stmt.live_out.intersection(_analysis.assigned_vars(stmt, self._message)) + live_out.intersection(_analysis.assigned_vars(stmt, self._message)) ) else: live_defs = list(_analysis.assigned_vars(stmt, self._message)) @@ -1131,7 +1131,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # analyze loop body exposed_uses = _analysis.exposed_uses(loop_stmt.body, self._message) vars_def_in_loop = _analysis.assigned_vars(loop_stmt.body, self._message) - loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) + live_out = self.meta[loop_stmt].live_out or set() + loop_state_vars = vars_def_in_loop.intersection(exposed_uses | live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) From c74b854d3647bf582f6ca3d83f2e04c59eca3b87 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 17:49:43 -0700 Subject: [PATCH 16/31] Update if Signed-off-by: Justin Chu --- onnxscript/_converter.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index d1e63fc707..d625f037db 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -1052,14 +1052,16 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: live_defs = list(_analysis.assigned_vars(stmt, self._message)) test = self._translate_expr(stmt.test, "cond").name lineno = self._source_of(stmt).lineno - thenGraph, sub_fct_then = self._translate_block( - stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt + + # TODO(justinchuby): Ensure the values are obtained from the live_defs + then_graph, sub_fct_then = self._translate_block( + stmt.body, f"then_graph_{lineno}", live_defs, parent_stmt=stmt ) - thenAttr = self._make_onnx_attr("then_branch", thenGraph) - elseGraph, sub_fct_else = self._translate_block( - stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt + then_attr = ir.AttrGraph("then_branch", then_graph) + else_graph, sub_fct_else = self._translate_block( + stmt.orelse, f"else_graph_{lineno}", live_defs, parent_stmt=stmt ) - elseAttr = self._make_onnx_attr("else_branch", elseGraph) + else_attr = ir.AttrGraph("else_branch", else_graph) def rename(x): r = self._generate_unique_name(x) @@ -1072,19 +1074,20 @@ def rename(x): # no break condition renamed = [rename(x) for x in live_defs] if not renamed: - self.fail(stmt, "A subgraph for a test do not have any output variable.") + # TODO(justinchuby): This needs comments. What is it doing? + self.fail(stmt, "A subgraph for an if condition has no outputs.") + # TODO(justinchuby): Collect the subfunctions to self sub_functions = {} sub_functions.update(sub_fct_then) sub_functions.update(sub_fct_else) if renamed == [test]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") self.emit( - renamed, - values.Op(self._default_opset, "If"), + "If", [test], - [thenAttr, elseAttr], - sub_functions=sub_functions, + renamed, + [then_attr, else_attr], ) def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: From 2c43e9aefdae34bbec713c766ce6720007358a16 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 18 Jul 2025 17:50:00 -0700 Subject: [PATCH 17/31] fmt Signed-off-by: Justin Chu --- onnxscript/_converter.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index d625f037db..e4f2a6c773 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -197,9 +197,7 @@ def __init__( default_opset: The default ONNX opset to use if no ONNX opset is specified in the script. """ if not isinstance(root, ast.FunctionDef): - raise TypeError( - f"Converter expects an AST FunctionDef node, got {type(root)}." - ) + raise TypeError(f"Converter expects an AST FunctionDef node, got {type(root)}.") self._ast_root = root self._opset = opset @@ -1329,9 +1327,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: # TODO: Does not yet handle nested functions within nested functions. self._current_fn.add_nested_function(function_ir) - def _translate_function_signature_common( - self, fn: ast.FunctionDef - ) -> ir.Function: + def _translate_function_signature_common(self, fn: ast.FunctionDef) -> ir.Function: """Translate a function signature (top-level or nested).""" args = fn.args if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg: From c53b8f3c4a1b38f0808b9faf59ed4d0d20235766 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Jul 2025 11:19:06 -0700 Subject: [PATCH 18/31] Update emit ordering Signed-off-by: Justin Chu --- onnxscript/_converter.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index e4f2a6c773..0cb2e74a39 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -331,14 +331,14 @@ def _to_onnx_var( # promote attribute to value result = self._generate_unique_name(target) attr = _to_onnx_ref_attr(val, info) - self.emit("Constant", [], [result], [attr]) + self.emit([], "Constant", [result], [attr]) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self._generate_unique_name(result + "_as_bool") self.emit( - "Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)] + [result], "Cast", [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)] ) return Variable(result_as_bool, castable=True) return Variable(result, castable=True) @@ -355,9 +355,9 @@ def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Varia def emit( self, + outputs: Sequence[str], op_type: str, inputs: Sequence[str], - outputs: Sequence[str], attrs: Sequence[ir.Attr] = (), domain: str = "", ): @@ -396,13 +396,13 @@ def _emit_const( except Exception as e: fail(info.msg(str(e))) - self.emit("Constant", [], [var_name], [ir.AttrTensor("value", tensor)]) + self.emit([], "Constant", [var_name], [ir.AttrTensor("value", tensor)]) return Variable(var_name, True) def _emit_copy(self, original_var: str, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self._generate_unique_name(suggested_name) - self.emit("Identity", [original_var], [new_var]) + self.emit([original_var], "Identity", [new_var]) return new_var def _eval_constant_expr(self, expr: ast.AST) -> PyValue: @@ -1082,8 +1082,8 @@ def rename(x): if renamed == [test]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") self.emit( - "If", [test], + "If", renamed, [then_attr, else_attr], ) From c84ad91b834d71cfd14451087f561a4c2c884f6e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Jul 2025 11:21:31 -0700 Subject: [PATCH 19/31] update emit call Signed-off-by: Justin Chu --- onnxscript/_converter.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 0cb2e74a39..96d3e55046 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -331,14 +331,17 @@ def _to_onnx_var( # promote attribute to value result = self._generate_unique_name(target) attr = _to_onnx_ref_attr(val, info) - self.emit([], "Constant", [result], [attr]) + self.emit([], "Constant", [result], attrs=[attr]) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self._generate_unique_name(result + "_as_bool") self.emit( - [result], "Cast", [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)] + [result], + "Cast", + [result_as_bool], + attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], ) return Variable(result_as_bool, castable=True) return Variable(result, castable=True) @@ -358,6 +361,7 @@ def emit( outputs: Sequence[str], op_type: str, inputs: Sequence[str], + *, attrs: Sequence[ir.Attr] = (), domain: str = "", ): @@ -396,7 +400,7 @@ def _emit_const( except Exception as e: fail(info.msg(str(e))) - self.emit([], "Constant", [var_name], [ir.AttrTensor("value", tensor)]) + self.emit([], "Constant", [var_name], attrs=[ir.AttrTensor("value", tensor)]) return Variable(var_name, True) def _emit_copy(self, original_var: str, suggested_name: str) -> str: @@ -522,7 +526,7 @@ def _translate_expr( target = "tmp" if target is None else target assert isinstance(target, str) result = self._generate_unique_name(target) - self.emit([result], callee, args, attrs) + self.emit([result], callee, args, attrs=attrs) return Variable(result) def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: @@ -620,9 +624,8 @@ def translate_slice_component( reshaped = self._generate_unique_name(f"{name}_reshaped") self.emit( [reshaped], - values.Op(self._default_opset, "Reshape"), + "Reshape", [name, one_1d().name], - [], ) return reshaped, None @@ -704,16 +707,16 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: if len(starts) > 1: axis_0_attr = self._make_onnx_attr("axis", 0) start_name = self._generate_unique_name(f"{var_name}_start") - self.emit([start_name], "Concat", starts, [axis_0_attr]) + self.emit([start_name], "Concat", starts, attrs=[axis_0_attr]) end_name = self._generate_unique_name(f"{var_name}_end") - self.emit([end_name], "Concat", ends, [axis_0_attr]) + self.emit([end_name], "Concat", ends, attrs=[axis_0_attr]) axes_name = self._generate_unique_name(f"{var_name}_axis") - self.emit([axes_name], "Concat", axes, [axis_0_attr]) + self.emit([axes_name], "Concat", axes, attrs=[axis_0_attr]) steps_name = self._generate_unique_name(f"{var_name}_step") - self.emit([steps_name], "Concat", steps, [axis_0_attr]) + self.emit([steps_name], "Concat", steps, attrs=[axis_0_attr]) else: start_name = starts[0] end_name = ends[0] @@ -759,7 +762,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: gathered = self._generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target - self.emit([gathered], "Gather", [str(result), index_value], [axis_attr]) + self.emit([gathered], "Gather", [str(result), index_value], attrs=[axis_attr]) result = gathered return Variable(result) @@ -971,7 +974,7 @@ def generate_onnx_name(x: ast.AST): return onnx_name outputs = [generate_onnx_name(x) for x in lhs.elts] - self.emit(outputs, callee, inputs, attrs) + self.emit(outputs, callee, inputs, attrs=attrs) else: self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") @@ -1085,7 +1088,7 @@ def rename(x): [test], "If", renamed, - [then_attr, else_attr], + attrs=[then_attr, else_attr], ) def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: @@ -1218,9 +1221,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.emit( [o_cond_out], - values.Op(self._default_opset, operator_name), + operator_name, [condition_name or o_cond_var], - [], ) self.ir_builder.add_output( @@ -1262,8 +1264,8 @@ def rename(x): onnx_outputs, "Loop", inputs, - attrs, - sub_functions=sub_functions, + attrs=attrs, + # sub_functions=sub_functions, ) def _translate_block( From 0af0707ef677babdf3038f2ee8adb830cd744759 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Jul 2025 11:23:45 -0700 Subject: [PATCH 20/31] attrs Signed-off-by: Justin Chu --- onnxscript/_converter.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 96d3e55046..6ac1518836 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -705,7 +705,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: steps.append(inputs[2]) if len(starts) > 1: - axis_0_attr = self._make_onnx_attr("axis", 0) + axis_0_attr = ir.AttrInt64("axis", 0) start_name = self._generate_unique_name(f"{var_name}_start") self.emit([start_name], "Concat", starts, attrs=[axis_0_attr]) @@ -755,7 +755,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: last_axis = None for axis, index_expr in non_scalar_indices: index_value = self._translate_expr(index_expr) - axis_attr = self._make_onnx_attr("axis", axis) + axis_attr = ir.AttrInt64("axis", axis) # use Gather to perform indexing # Assign gathered value to either temporary or final target if axis != last_axis: # use temporary to store result of Gather @@ -810,13 +810,13 @@ def _translate_binary_op_expr(self, node: ast.BinOp): # attribute fmod=1 is added in that case. cst = self._eval_constant_expr(node.right) if isinstance(cst, float): - attr = [self._make_onnx_attr("fmod", 1)] + attr = [ir.AttrInt64("fmod", 1)] - op = values.Op(self._default_opset, _PRIMOP_MAP[op]) + onnx_op = _PRIMOP_MAP[op] left, right = self._cast_like_binary_expression( - op, self._translate_expr(node.left), self._translate_expr(node.right) + onnx_op, self._translate_expr(node.left), self._translate_expr(node.right) ) - return op, [left, right], attr + return onnx_op, [left, right], attr def _translate_unary_op_expr(self, node): op = type(node.op) From 2196f99180b764d43046e48c094abed1cb9bd961 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Jul 2025 16:34:35 -0700 Subject: [PATCH 21/31] return values Signed-off-by: Justin Chu --- onnxscript/_converter.py | 54 ++++++++++++++-------------------------- onnxscript/values.py | 1 + 2 files changed, 20 insertions(+), 35 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 6ac1518836..fc3441818a 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -21,21 +21,20 @@ Union, ) -import onnx import onnx_ir as ir from onnxscript.ir import _schemas import onnxscript from onnxscript import irbuilder, onnx_types, sourceinfo, values from onnxscript import type_annotation as ta -from onnxscript._internal import _analysis, ast_utils, autocast, param_manipulation +from onnxscript._internal import _analysis, ast_utils, autocast if TYPE_CHECKING: # The type-alias LocalSymValue represents the types of values that local names in a # script-function may be bound to during translation, (ONNX IR values). # TODO(rama): Rationalize this and values.SymbolValue - LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction] + LocalSymValue = Union[values.SymbolValue, ir.Function] # The type-alias PyValue is used to represent the types of python values that may be used # in an ONNX Script function. @@ -115,28 +114,11 @@ def ignore(cond, msg): } -class Variable: - """Represents an ONNX variable. +_CASTABLE_FIELD = "pkg.onnxscript.converter.castable" - TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this - converter. - """ - - def __init__(self, name: str, castable: bool = False): - """Initialize the instance. - - Args: - name: Name of the ONNX variable - castable: Whether this variable is castable to a desired target type. - Used for ONNX variables representing constants created from python values - like 0 or 1 or 0.5 which are treated as polymorphic values castable to other - types as needed. - """ - self.name = name - self.is_castable = castable - - def __str__(self) -> str: - return self.name +def mark_castable(value: ir.Value): + """Mark an ONNX value as auto-castable.""" + value.meta[_CASTABLE_FIELD] = True @dataclasses.dataclass @@ -227,6 +209,8 @@ def __init__( graph=ir.Graph((), (), nodes=[]), attributes={}, ) + # A mapping from value names to the values for each function + # self._scoped_values: dict[ir.Function, dict[str, ir.Value]] = {} self._nextvar: int = 0 self._used_vars: set[str] = set() self._locals: list[dict[str, LocalSymValue]] = [{}] @@ -325,26 +309,25 @@ def _to_onnx_var( target: PreferredName = "tmp", *, info: sourceinfo.SourceInfo, - ) -> Variable: + ) -> ir.Value: """Convert a value to an ONNX variable.""" if isinstance(val, values.AttrRef): # promote attribute to value result = self._generate_unique_name(target) attr = _to_onnx_ref_attr(val, info) - self.emit([], "Constant", [result], attrs=[attr]) + result_val = self.emit([result], "Constant", [], attrs=[attr])[0] if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. result_as_bool = self._generate_unique_name(result + "_as_bool") - self.emit( - [result], - "Cast", + return self.emit( [result_as_bool], + "Cast", + [result], attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], - ) - return Variable(result_as_bool, castable=True) - return Variable(result, castable=True) + )[0] + return result_val if isinstance(val, values.Dynamic): return Variable(val.value) @@ -364,16 +347,17 @@ def emit( *, attrs: Sequence[ir.Attr] = (), domain: str = "", - ): + ) -> Sequence[ir.Value]: """Emit an ONNX operator with the given inputs, outputs, and attributes.""" node = ir.Node( domain=domain, op_type=op_type, - inputs=[self._lookup(inp, self._source_of(inputs[0])) for inp in inputs], + inputs=[self._lookup(inp, self._source_of(inp)) for inp in inputs], attributes=attrs, - outputs=[self._lookup(out, self._source_of(outputs[0])) for out in outputs], + outputs=[self._lookup(out, self._source_of(out)) for out in outputs], ) self._current_fn.append(node) + return node.outputs def _emit_const( self, diff --git a/onnxscript/values.py b/onnxscript/values.py index 7df8a6ce1c..a1c6322be9 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -746,6 +746,7 @@ def __init__( raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") self.typeinfo = typeinfo + class DynamicKind(IntFlag): Unknown = 0 Input = 1 From b56a161643c2085737c1285dd4227da59ff22058 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 24 Jul 2025 10:50:19 -0700 Subject: [PATCH 22/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index fc3441818a..0c246bf930 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -385,6 +385,7 @@ def _emit_const( fail(info.msg(str(e))) self.emit([], "Constant", [var_name], attrs=[ir.AttrTensor("value", tensor)]) + # TODO: I am here return Variable(var_name, True) def _emit_copy(self, original_var: str, suggested_name: str) -> str: From b04a45560b4682a70f73566520c3ffec5ef1a2d2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 24 Jul 2025 17:47:28 -0700 Subject: [PATCH 23/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 0c246bf930..8a01931e1b 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -25,7 +25,7 @@ from onnxscript.ir import _schemas import onnxscript -from onnxscript import irbuilder, onnx_types, sourceinfo, values +from onnxscript import onnx_types, sourceinfo, values from onnxscript import type_annotation as ta from onnxscript._internal import _analysis, ast_utils, autocast @@ -310,7 +310,7 @@ def _to_onnx_var( *, info: sourceinfo.SourceInfo, ) -> ir.Value: - """Convert a value to an ONNX variable.""" + """Convert a Python or symbolic value to an ONNX Value.""" if isinstance(val, values.AttrRef): # promote attribute to value result = self._generate_unique_name(target) @@ -330,7 +330,8 @@ def _to_onnx_var( return result_val if isinstance(val, values.Dynamic): - return Variable(val.value) + # A value in ONNX + return ir.Value(name=val.value) # Assume value is a python-value convertible to a tensor return self._emit_const(val, target, info) From e4c95767dae62da60913919e95b89c3831169e25 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 25 Jul 2025 12:33:30 -0700 Subject: [PATCH 24/31] wip Signed-off-by: Justin Chu --- onnxscript/_converter.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 8a01931e1b..333902a598 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -365,7 +365,7 @@ def _emit_const( pyvalue: PyValue, suggested_name: PreferredName | None, info: sourceinfo.SourceInfo, - ) -> Variable: + ) -> ir.Value: """Emit a constant value as an ONNX Constant node.""" # Obtain a name for the constant if suggested_name is None: @@ -385,15 +385,14 @@ def _emit_const( except Exception as e: fail(info.msg(str(e))) - self.emit([], "Constant", [var_name], attrs=[ir.AttrTensor("value", tensor)]) - # TODO: I am here - return Variable(var_name, True) + const = self.emit([var_name], "Constant", [], attrs=[ir.AttrTensor("value", tensor)])[0] + mark_castable(const) + return const - def _emit_copy(self, original_var: str, suggested_name: str) -> str: + def _emit_copy(self, original_var: str, suggested_name: str) -> ir.Value: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self._generate_unique_name(suggested_name) - self.emit([original_var], "Identity", [new_var]) - return new_var + return self.emit([new_var], "Identity", [original_var])[0] def _eval_constant_expr(self, expr: ast.AST) -> PyValue: """Evaluates a sub-expression that is assumed to represent a constant value. @@ -1002,10 +1001,11 @@ def check_num_outputs(n): ) ) - def ret(exp, i, suffix): + def ret(exp: ast.AST, i: int, suffix: str) -> str: preferred_name = f"return_val{suffix}" return_var = self._translate_expr(exp, preferred_name).name val = self._lookup(return_var, self._source_of(exp), False) + assert type(val) is values.Dynamic if val and val.kind == values.DynamicKind.Input: # In ONNX, a graph-input cannot be an output of the graph. # We need to insert a copy. @@ -1013,13 +1013,17 @@ def ret(exp, i, suffix): for prev_output in self._current_fn.outputs: if prev_output.name == return_var: # ONNX does not allow duplicate output names. + # TODO(justinchuby): Maybe pass in ir.Value in _emit_copy return_var = self._emit_copy(return_var, f"{return_var}_copy") break if self.returntype is None: t = None else: t = self.returntype[i] - self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt)) + self._current_fn.outputs.append(return_var) + # TODO(justinchuby): Set type for return var from t + # TODO(justinchuby): Get self._source_of(stmt) + # self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt)) return return_var val = stmt.value From 11a735ef3fdd13c52a0a1f25e325e533b2f04cb7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Jul 2025 11:53:37 -0700 Subject: [PATCH 25/31] refactor Signed-off-by: Justin Chu --- onnxscript/_converter.py | 151 +++++++++++++++++++++++++++++++-------- onnxscript/values.py | 87 ---------------------- 2 files changed, 120 insertions(+), 118 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 333902a598..b189111d95 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -29,37 +29,6 @@ from onnxscript import type_annotation as ta from onnxscript._internal import _analysis, ast_utils, autocast -if TYPE_CHECKING: - # The type-alias LocalSymValue represents the types of values that local names in a - # script-function may be bound to during translation, (ONNX IR values). - # TODO(rama): Rationalize this and values.SymbolValue - - LocalSymValue = Union[values.SymbolValue, ir.Function] - - # The type-alias PyValue is used to represent the types of python values that may be used - # in an ONNX Script function. - # TODO(rama): Flesh out the set of valid types here. These include values such as - # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for - # use as value-parameters or attribute-parameters in an ONNX call (Node). - - PyValue = Any - - # The type-alias SymValue denotes values that an identifier may be bound to during - # translation. A local name will be bound to a LocalSymValue, while a global name - # will be bound to a PyValue. - - SymValue = Union[LocalSymValue, PyValue] - - # PreferredName is a type-alias used to represent the preferred name used in the generated - # ONNX for a value returned by an expression. There is no guarantee that the specified - # name will be used exactly. The converter will modify the name (with a suffix), - # if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement). - - PreferredName = str - - # The type-alias OnnxVar indicates variable names used in the generated ONNX. - OnnxVarName = str - logger = logging.getLogger(__name__) @@ -116,6 +85,126 @@ def ignore(cond, msg): _CASTABLE_FIELD = "pkg.onnxscript.converter.castable" + + +class SymbolValue: + """Represents script-time value information about named variables used in a script. + + At translation-time, the (local) variables of a script, including its parameters, + are bound to a SymbolValue. + + SymbolValues fall into the following categories: + + AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes + + Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs. + Dynamic values include input-parameters of the script, as well intermediate + values computed in the script. + + For example, consider the following script definition: + :: + + @script() + def ThresholdedRelu(X, alpha: float): + zero = op.CastLike(0, X) + return op.Where(X > alpha, X, zero) + + Here, `X` has a Dynamic value, `alpha` has an AttrRef value, and `zero` + has a Dynamic value. + + Scripts may also contain references to global variables, but the translator + does not associate a SymbolValue with them. The python value of global variables + is used directly in the translation, and such global variables are intended + to be used for limited purposes, namely: + * To identify an opset + * To represent constant-values, translated into ONNX constants. + """ + + def __init__(self, info: sourceinfo.SourceInfo) -> None: + if not isinstance(info, sourceinfo.SourceInfo): + raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.") + self.info = info + + +class AttrRef(SymbolValue): + def __init__( + self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo + ) -> None: + """Initializes AttrRef. + + Arguments: + attr_name: name of the attribute-parameter + typeinfo: type annotation of the attribute. + op's attributes in ONNX are usually single type or list of single type. + info: for debugging use. + """ + super().__init__(info) + self.value = attr_name + + if not isinstance(typeinfo, (type, _GenericAlias)): + # typing._GenericAlias for List[int] and List[str], etc. + raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") + self.typeinfo = typeinfo + + +class DynamicKind(IntFlag): + Unknown = 0 + Input = 1 + Output = 2 + Intermediate = 4 + Loop = 8 + + +class Dynamic(SymbolValue): + def __init__( + self, onnx_var: str, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None + ) -> None: + """Initializes Dynamic. + + Arguments: + onnx_var: the name of the ONNX variable used to represent this value + kind: the DynamicKind of this variable + info: source-location information for error-messages/debugging + typeinfo: type-information for the value + """ + super().__init__(info) + assert isinstance(kind, DynamicKind) + self.value = onnx_var + self.kind = kind + self.typeinfo = typeinfo + + +# The type-alias LocalSymValue represents the types of values that local names in a +# script-function may be bound to during translation, (ONNX IR values). +# TODO(rama): Rationalize this and values.SymbolValue + +LocalSymValue = Union[SymbolValue, ir.Function] + +# The type-alias PyValue is used to represent the types of python values that may be used +# in an ONNX Script function. +# TODO(rama): Flesh out the set of valid types here. These include values such as +# 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for +# use as value-parameters or attribute-parameters in an ONNX call (Node). + +PyValue = Any + +# The type-alias SymValue denotes values that an identifier may be bound to during +# translation. A local name will be bound to a LocalSymValue, while a global name +# will be bound to a PyValue. + +SymValue = Union[LocalSymValue, PyValue] + +# PreferredName is a type-alias used to represent the preferred name used in the generated +# ONNX for a value returned by an expression. There is no guarantee that the specified +# name will be used exactly. The converter will modify the name (with a suffix), +# if necesssary, to ensure that it is unique (to ensure ONNX's SSA requirement). + +PreferredName = str + +# The type-alias OnnxVar indicates variable names used in the generated ONNX. +OnnxVarName = str + + def mark_castable(value: ir.Value): """Mark an ONNX value as auto-castable.""" value.meta[_CASTABLE_FIELD] = True diff --git a/onnxscript/values.py b/onnxscript/values.py index a1c6322be9..80c7326c66 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -685,90 +685,3 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # argument order from the Python function definition, which is lost in OpSchema. self._param_schemas = _param_schemas_from_function_ir(self.function_ir) return self._param_schemas - - -class SymbolValue: - """Represents script-time value information about named variables used in a script. - - At translation-time, the (local) variables of a script, including its parameters, - are bound to a SymbolValue. - - SymbolValues fall into the following categories: - - AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes - - Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs. - Dynamic values include input-parameters of the script, as well intermediate - values computed in the script. - - For example, consider the following script definition: - :: - - @script() - def ThresholdedRelu(X, alpha: float): - zero = op.CastLike(0, X) - return op.Where(X > alpha, X, zero) - - Here, `X` has a Dynamic value, `alpha` has an AttrRef value, and `zero` - has a Dynamic value. - - Scripts may also contain references to global variables, but the translator - does not associate a SymbolValue with them. The python value of global variables - is used directly in the translation, and such global variables are intended - to be used for limited purposes, namely: - * To identify an opset - * To represent constant-values, translated into ONNX constants. - """ - - def __init__(self, info: sourceinfo.SourceInfo) -> None: - if not isinstance(info, sourceinfo.SourceInfo): - raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.") - self.info = info - - -class AttrRef(SymbolValue): - def __init__( - self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo - ) -> None: - """Initializes AttrRef. - - Arguments: - attr_name: name of the attribute-parameter - typeinfo: type annotation of the attribute. - op's attributes in ONNX are usually single type or list of single type. - info: for debugging use. - """ - super().__init__(info) - self.value = attr_name - - if not isinstance(typeinfo, (type, _GenericAlias)): - # typing._GenericAlias for List[int] and List[str], etc. - raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") - self.typeinfo = typeinfo - - -class DynamicKind(IntFlag): - Unknown = 0 - Input = 1 - Output = 2 - Intermediate = 4 - Loop = 8 - - -class Dynamic(SymbolValue): - def __init__( - self, onnx_var: str, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None - ) -> None: - """Initializes Dynamic. - - Arguments: - onnx_var: the name of the ONNX variable used to represent this value - kind: the DynamicKind of this variable - info: source-location information for error-messages/debugging - typeinfo: type-information for the value - """ - super().__init__(info) - assert isinstance(kind, DynamicKind) - self.value = onnx_var - self.kind = kind - self.typeinfo = typeinfo From 0cd1f207d85d9a7f55cc6a4c79a214128e737d02 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Jul 2025 17:30:28 -0700 Subject: [PATCH 26/31] _ValueEnvironment Signed-off-by: Justin Chu --- onnxscript/_converter.py | 77 +++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 33 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index b189111d95..7adbe6fab9 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -19,6 +19,7 @@ Sequence, Tuple, Union, + _GenericAlias ) import onnx_ir as ir @@ -222,6 +223,49 @@ class ASTMeta: live_in: set[str] | None = None +class _ValueEnvironment: + def __init__(self, converter: Converter): + self._sym_value_to_onnx_values: dict[SymbolValue, ir.Value] = {} + self._converter = converter + + def get_or_create_value( + self, val: SymbolValue, info: sourceinfo.SourceInfo + ) -> ir.Value: + """Get or create an ONNX Value for a SymbolValue.""" + if val in self._sym_value_to_onnx_values: + return self._sym_value_to_onnx_values[val] + if isinstance(val, AttrRef): + # promote attribute to value + result_name = self._converter._generate_unique_name("v") + attr = _to_onnx_ref_attr(val, info) + result = self._converter.emit([result_name], "Constant", [], attrs=[attr])[0] + if ta.base_type_is_bool(val.typeinfo): + # ONNX attributes use an int-encoding for bools, but ONNX tensor types + # distinguish between int and bool. So we cast the int tensor to a bool tensor, + # to promote a (python) bool attribute to a ONNX bool tensor. + result_as_bool_name = self._converter._generate_unique_name(f"{result_name}_as_bool") + result = self._converter.emit( + [result_as_bool_name], + "Cast", + [result_name], + attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], + )[0] + + self._sym_value_to_onnx_values[val] = result + return result + + if isinstance(val, Dynamic): + # A value in ONNX + result = ir.Value(name=val.value) + self._sym_value_to_onnx_values[val] = result + return result + + # Assume value is a python-value convertible to a tensor + result = self._converter._emit_const(val, None, info) + self._sym_value_to_onnx_values[val] = result + return result + + class Converter: """Main class to translate python code into ONNX operators. @@ -392,39 +436,6 @@ def _generate_unique_name(self, candidate: str = "tmp") -> str: self._used_vars.add(r) return r - def _to_onnx_var( - self, - val: values.SymbolValue | PyValue, - target: PreferredName = "tmp", - *, - info: sourceinfo.SourceInfo, - ) -> ir.Value: - """Convert a Python or symbolic value to an ONNX Value.""" - if isinstance(val, values.AttrRef): - # promote attribute to value - result = self._generate_unique_name(target) - attr = _to_onnx_ref_attr(val, info) - result_val = self.emit([result], "Constant", [], attrs=[attr])[0] - if ta.base_type_is_bool(val.typeinfo): - # ONNX attributes use an int-encoding for bools, but ONNX tensor types - # distinguish between int and bool. So we cast the int tensor to a bool tensor, - # to promote a (python) bool attribute to a ONNX bool tensor. - result_as_bool = self._generate_unique_name(result + "_as_bool") - return self.emit( - [result_as_bool], - "Cast", - [result], - attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], - )[0] - return result_val - - if isinstance(val, values.Dynamic): - # A value in ONNX - return ir.Value(name=val.value) - - # Assume value is a python-value convertible to a tensor - return self._emit_const(val, target, info) - def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable: """Convert a python variable to an ONNX variable.""" return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) From d98e8b5e6998e29fd61b481625627715c4a14cd8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Jul 2025 17:31:13 -0700 Subject: [PATCH 27/31] emit_const Signed-off-by: Justin Chu --- onnxscript/_converter.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 7adbe6fab9..36d9acaa9d 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -261,7 +261,7 @@ def get_or_create_value( return result # Assume value is a python-value convertible to a tensor - result = self._converter._emit_const(val, None, info) + result = self._converter.emit_const(val, None, info) self._sym_value_to_onnx_values[val] = result return result @@ -348,6 +348,7 @@ def __init__( self._used_vars: set[str] = set() self._locals: list[dict[str, LocalSymValue]] = [{}] self._finalized = False + self._value_env = _ValueEnvironment(self) self.meta: defaultdict[ast.AST, ASTMeta] = defaultdict(ASTMeta) # def _init_function_translation(self) -> None: @@ -460,7 +461,7 @@ def emit( self._current_fn.append(node) return node.outputs - def _emit_const( + def emit_const( self, pyvalue: PyValue, suggested_name: PreferredName | None, @@ -600,7 +601,7 @@ def _translate_expr( elif isinstance(node, ast.Subscript): r = self._translate_subscript_expr(node, target) elif _is_constant_expr(node): - r = self._emit_const(self._eval_constant_expr(node), target, self._source_of(node)) + r = self.emit_const(self._eval_constant_expr(node), target, self._source_of(node)) else: raise ValueError( self._message(node, f"Unsupported expression type {type(node)!r}.") @@ -676,7 +677,7 @@ def _translate_subscript_expr( def const_1d(value, name: Optional[str] = None): nonlocal cached_int_consts if value not in cached_int_consts: - cached_int_consts[value] = self._emit_const([value], name, info) + cached_int_consts[value] = self.emit_const([value], name, info) return cached_int_consts[value] def one_1d(): @@ -815,7 +816,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: "Slice", [var_name, start_name, end_name, axes_name, steps_name], ) - squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) + squeezed_axes = self.emit_const(squeezed_axes, "squeezed_axes", info) if non_scalar_indices: # use temporary to store result of squeeze result = self._generate_unique_name(f"{var_name}_squeezed") @@ -1231,7 +1232,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: outputs = list(loop_state_vars | scan_outputs) # loop-condition: - # o_loop_condition = self._emit_const(True, "true", self._source_of(loop_stmt)) + # o_loop_condition = self.emit_const(True, "true", self._source_of(loop_stmt)) # build loop_body self._enter_scope("loop_body", loop_stmt) From f1c10c6384e353b1a774e613d376888643945d72 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 15 Aug 2025 17:26:51 -0700 Subject: [PATCH 28/31] _ValueEnvironment Signed-off-by: Justin Chu --- onnxscript/_converter.py | 141 +++++++++------------------------------ 1 file changed, 33 insertions(+), 108 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 36d9acaa9d..1fc4abb3f7 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -85,69 +85,10 @@ def ignore(cond, msg): _CASTABLE_FIELD = "pkg.onnxscript.converter.castable" +_SOURCEINFO_FIELD = "pkg.onnxscript.sourceinfo" -class SymbolValue: - """Represents script-time value information about named variables used in a script. - - At translation-time, the (local) variables of a script, including its parameters, - are bound to a SymbolValue. - - SymbolValues fall into the following categories: - - AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes - - Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs. - Dynamic values include input-parameters of the script, as well intermediate - values computed in the script. - - For example, consider the following script definition: - :: - - @script() - def ThresholdedRelu(X, alpha: float): - zero = op.CastLike(0, X) - return op.Where(X > alpha, X, zero) - - Here, `X` has a Dynamic value, `alpha` has an AttrRef value, and `zero` - has a Dynamic value. - - Scripts may also contain references to global variables, but the translator - does not associate a SymbolValue with them. The python value of global variables - is used directly in the translation, and such global variables are intended - to be used for limited purposes, namely: - * To identify an opset - * To represent constant-values, translated into ONNX constants. - """ - - def __init__(self, info: sourceinfo.SourceInfo) -> None: - if not isinstance(info, sourceinfo.SourceInfo): - raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.") - self.info = info - - -class AttrRef(SymbolValue): - def __init__( - self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo - ) -> None: - """Initializes AttrRef. - - Arguments: - attr_name: name of the attribute-parameter - typeinfo: type annotation of the attribute. - op's attributes in ONNX are usually single type or list of single type. - info: for debugging use. - """ - super().__init__(info) - self.value = attr_name - - if not isinstance(typeinfo, (type, _GenericAlias)): - # typing._GenericAlias for List[int] and List[str], etc. - raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") - self.typeinfo = typeinfo - - class DynamicKind(IntFlag): Unknown = 0 Input = 1 @@ -155,31 +96,11 @@ class DynamicKind(IntFlag): Intermediate = 4 Loop = 8 - -class Dynamic(SymbolValue): - def __init__( - self, onnx_var: str, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None - ) -> None: - """Initializes Dynamic. - - Arguments: - onnx_var: the name of the ONNX variable used to represent this value - kind: the DynamicKind of this variable - info: source-location information for error-messages/debugging - typeinfo: type-information for the value - """ - super().__init__(info) - assert isinstance(kind, DynamicKind) - self.value = onnx_var - self.kind = kind - self.typeinfo = typeinfo - - # The type-alias LocalSymValue represents the types of values that local names in a # script-function may be bound to during translation, (ONNX IR values). # TODO(rama): Rationalize this and values.SymbolValue -LocalSymValue = Union[SymbolValue, ir.Function] +LocalSymValue = Union[ir.Value, ir.Attr, ir.Function] # The type-alias PyValue is used to represent the types of python values that may be used # in an ONNX Script function. @@ -187,7 +108,7 @@ def __init__( # 1 (int), 1.0 (float), [2, 4], [1.0], etc. which will be converted to ONNX, for # use as value-parameters or attribute-parameters in an ONNX call (Node). -PyValue = Any +PyValue = Union[int, float, str, bool, Sequence[int], Sequence[float], Sequence[str], Sequence[bool]] # The type-alias SymValue denotes values that an identifier may be bound to during # translation. A local name will be bound to a LocalSymValue, while a global name @@ -210,6 +131,10 @@ def mark_castable(value: ir.Value): """Mark an ONNX value as auto-castable.""" value.meta[_CASTABLE_FIELD] = True +def set_sourceinfo(value: ir.Value, info: sourceinfo.SourceInfo): + """Set the source information for an ONNX value.""" + value.meta[_SOURCEINFO_FIELD] = info + @dataclasses.dataclass class ASTMeta: @@ -225,45 +150,45 @@ class ASTMeta: class _ValueEnvironment: def __init__(self, converter: Converter): - self._sym_value_to_onnx_values: dict[SymbolValue, ir.Value] = {} + self._py_var_name_to_ir_values: dict[str, ir.Value] = {} + self._py_var_name_to_ir_attr_refs: dict[str, ir.Attr] = {} + self._py_var_name_to_py_values: dict[str, PyValue] = {} self._converter = converter def get_or_create_value( - self, val: SymbolValue, info: sourceinfo.SourceInfo + self, var: str, info: sourceinfo.SourceInfo ) -> ir.Value: - """Get or create an ONNX Value for a SymbolValue.""" - if val in self._sym_value_to_onnx_values: - return self._sym_value_to_onnx_values[val] - if isinstance(val, AttrRef): + """Get or create an IR value from Python variable name.""" + if var in self._py_var_name_to_ir_values: + return self._py_var_name_to_ir_values[var] + if var in self._py_var_name_to_ir_attr_refs: # promote attribute to value - result_name = self._converter._generate_unique_name("v") - attr = _to_onnx_ref_attr(val, info) - result = self._converter.emit([result_name], "Constant", [], attrs=[attr])[0] - if ta.base_type_is_bool(val.typeinfo): + attr = self._py_var_name_to_ir_attr_refs[var] + result = self._converter.op( + "Constant", [], attrs=[attr] + ) + if is_base_type_bool(attr): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. - result_as_bool_name = self._converter._generate_unique_name(f"{result_name}_as_bool") - result = self._converter.emit( - [result_as_bool_name], + result = self._converter.op( "Cast", - [result_name], + [result], attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], - )[0] - - self._sym_value_to_onnx_values[val] = result - return result + ) - if isinstance(val, Dynamic): - # A value in ONNX - result = ir.Value(name=val.value) - self._sym_value_to_onnx_values[val] = result + self._py_var_name_to_ir_values[var] = result return result + if var in self._py_var_name_to_py_values: + # Assume value is a python-value convertible to a tensor + result = self._converter.op( + "Constant", [], attrs=[ir.AttrTensor("value", ir.tensor(var, name=var))] + ) + mark_castable(result) + self._py_var_name_to_ir_values[var] = result - # Assume value is a python-value convertible to a tensor - result = self._converter.emit_const(val, None, info) - self._sym_value_to_onnx_values[val] = result - return result + # TODO(justinchuby): Update error message + raise ValueError(f"Variable '{var}' is unbound.") class Converter: From ce2428f07592356bd135352441a3a30de153c088 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 15 Aug 2025 17:38:50 -0700 Subject: [PATCH 29/31] fixme Signed-off-by: Justin Chu --- onnxscript/_converter.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 1fc4abb3f7..dbce5a1c01 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -136,6 +136,11 @@ def set_sourceinfo(value: ir.Value, info: sourceinfo.SourceInfo): value.meta[_SOURCEINFO_FIELD] = info +def is_base_type_bool(attr: ir.Attr) -> bool: + """Check if the attribute is a boolean type.""" + # FIXME: Add meta to attributes + attr.meta[_SOURCEINFO_FIELD] + @dataclasses.dataclass class ASTMeta: """Metadata for an AST node. @@ -276,20 +281,6 @@ def __init__( self._value_env = _ValueEnvironment(self) self.meta: defaultdict[ast.AST, ASTMeta] = defaultdict(ASTMeta) - # def _init_function_translation(self) -> None: - # """Initialize self for translating a new (top-level) function.""" - # self._outer = [] - # # TODO(justinchuby): Update this - # self._current_fn = ir.Function( - # domain=self._opset.domain, - # name="", - # graph=ir.Graph((), (), nodes=[]), - # attributes={}, - # ) - # self._nextvar = 0 - # self._used_vars = set() - # self._locals: List[Dict[str, LocalSymValue]] = [{}] - def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: return sourceinfo.SourceInfo(node, self._source, self._current_fn.name) From bba0ff69f95b478d871d78792a84a3ab1462f3c8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 29 Oct 2025 16:30:03 -0700 Subject: [PATCH 30/31] Merge Signed-off-by: Justin Chu --- onnxscript/_converter.py | 9 ++------- onnxscript/_internal/{_analysis.py => analysis.py} | 6 +----- .../_internal/{_analysis_test.py => analysis_test.py} | 4 ++-- onnxscript/ir/_schemas.py | 1 - onnxscript/irbuilder.py | 1 + onnxscript/values.py | 4 +--- 6 files changed, 7 insertions(+), 18 deletions(-) rename onnxscript/_internal/{_analysis.py => analysis.py} (98%) rename onnxscript/_internal/{_analysis_test.py => analysis_test.py} (98%) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 7a054b13bb..53fd4258be 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -5,31 +5,26 @@ from __future__ import annotations import ast -from collections import defaultdict import dataclasses import logging +from collections import defaultdict from typing import ( - TYPE_CHECKING, Any, Dict, List, - Mapping, NoReturn, Optional, Sequence, Tuple, Union, - _GenericAlias ) import onnx_ir as ir -from onnxscript.ir import _schemas import onnxscript from onnxscript import onnx_types, sourceinfo, values from onnxscript import type_annotation as ta -from onnxscript._internal import _analysis, ast_utils, autocast - +from onnxscript._internal import analysis, ast_utils, autocast logger = logging.getLogger(__name__) diff --git a/onnxscript/_internal/_analysis.py b/onnxscript/_internal/analysis.py similarity index 98% rename from onnxscript/_internal/_analysis.py rename to onnxscript/_internal/analysis.py index ad03e95679..0641343729 100644 --- a/onnxscript/_internal/_analysis.py +++ b/onnxscript/_internal/analysis.py @@ -4,15 +4,11 @@ from __future__ import annotations import ast -from typing import Any, Optional, Sequence, TYPE_CHECKING -from collections import defaultdict +from typing import Any, Optional, Sequence from onnxscript import sourceinfo from onnxscript._internal import ast_utils -if TYPE_CHECKING: - from onnxscript import _converter - def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str: if not isinstance(for_stmt.target, ast.Name): diff --git a/onnxscript/_internal/_analysis_test.py b/onnxscript/_internal/analysis_test.py similarity index 98% rename from onnxscript/_internal/_analysis_test.py rename to onnxscript/_internal/analysis_test.py index ee418fb889..7a7e5feaa0 100644 --- a/onnxscript/_internal/_analysis_test.py +++ b/onnxscript/_internal/analysis_test.py @@ -6,7 +6,7 @@ import unittest from typing import Any -from onnxscript._internal import _analysis, ast_utils +from onnxscript._internal import analysis, ast_utils from onnxscript.onnx_opset import opset15 as op from onnxscript.sourceinfo import formatter @@ -33,7 +33,7 @@ def generic_visit(self, node): class TestLivenessAnalysis(unittest.TestCase): def analyze(self, fun): source, parse_tree = ast_utils.get_src_and_ast(fun) - analyzer = _analysis.AstAnalyzer(parse_tree, formatter(source)) + analyzer = analysis.AstAnalyzer(parse_tree, formatter(source)) visitor = AnalysisResultsVisitor(analyzer) visitor.visit(parse_tree) return visitor.results diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index 2a2527e31b..ceebbd807f 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -3,7 +3,6 @@ from __future__ import annotations import collections.abc -import copy import dataclasses import inspect import logging diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 8674a6331f..366b6fd1c5 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -194,6 +194,7 @@ class IRStmt: - `sub_functions`: A dictionary of sub-functions that this statement may call, mapping function names to `onnx.FunctionProto` instances. """ + def __init__( self, result: Sequence[str], diff --git a/onnxscript/values.py b/onnxscript/values.py index 80c7326c66..56f04d155a 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -8,7 +8,6 @@ import logging import types import typing -from enum import IntFlag from typing import ( # type: ignore[attr-defined] Any, Callable, @@ -18,14 +17,13 @@ Protocol, Sequence, TypeVar, - _GenericAlias, ) import onnx import onnx.defs from typing_extensions import ParamSpec -from onnxscript import _converter, irbuilder, sourceinfo, type_annotation +from onnxscript import _converter, irbuilder, type_annotation from onnxscript._internal import ast_utils, deprecation from onnxscript.ir import _schemas From 7a0b016f041662a8e4a2243876987711041c9ae2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 29 Oct 2025 16:42:17 -0700 Subject: [PATCH 31/31] WIP Signed-off-by: Justin Chu --- onnxscript/_converter.py | 220 +++++++++++++++++++++++++++++++-------- 1 file changed, 176 insertions(+), 44 deletions(-) diff --git a/onnxscript/_converter.py b/onnxscript/_converter.py index 53fd4258be..a19d005857 100644 --- a/onnxscript/_converter.py +++ b/onnxscript/_converter.py @@ -5,13 +5,14 @@ from __future__ import annotations import ast +from collections import defaultdict import dataclasses import logging -from collections import defaultdict from typing import ( Any, Dict, List, + Mapping, NoReturn, Optional, Sequence, @@ -20,12 +21,14 @@ ) import onnx_ir as ir +from onnxscript.ir import _schemas import onnxscript from onnxscript import onnx_types, sourceinfo, values from onnxscript import type_annotation as ta from onnxscript._internal import analysis, ast_utils, autocast + logger = logging.getLogger(__name__) @@ -243,28 +246,14 @@ def __init__( if global_names is not None: # We make a copy in case function eval modifies it. - self.globals = global_names.copy() - self.this_module = opset - self.default_opset_ = default_opset - - # States initialized by `_init_function_translation` - self._outer: List[irbuilder.IRFunction] = [] - self._current_fn: irbuilder.IRFunction = None - self._nextvar: int = 0 - self._used_vars: set[str] = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] - self._analyzer: analysis.AstAnalyzer | None = None - - @property - def analyzer(self) -> analysis.AstAnalyzer: - if self._analyzer is None: - raise RuntimeError("Analyzer not initialized.") - return self._analyzer + self._globals = global_names.copy() + else: + self._globals = {} - @property - def default_opset(self) -> values.Opset: - if self.default_opset_ is None: - raise RuntimeError( + self._source = source + self._default_opset = default_opset or _find_onnx_opset(root, self._globals) + if self._default_opset is None: + raise ValueError( "default_opset must be specified in script for functions " "that do not contain any use of an ONNX opset." ) @@ -288,8 +277,15 @@ def default_opset(self) -> values.Opset: self._locals: list[dict[str, LocalSymValue]] = [{}] self._finalized = False self._value_env = _ValueEnvironment(self) + self._analyzer: analysis.AstAnalyzer | None = None self.meta: defaultdict[ast.AST, ASTMeta] = defaultdict(ASTMeta) + @property + def analyzer(self) -> analysis.AstAnalyzer: + if self._analyzer is None: + raise RuntimeError("Analyzer not initialized.") + return self._analyzer + def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: return sourceinfo.SourceInfo(node, self._source, self._current_fn.name) @@ -1427,27 +1423,163 @@ def _translate_function_def(self, node: ast.FunctionDef) -> ir.Function: # Update docstring if available if docstring := ast.get_docstring(node): self._current_fn.doc_string = docstring + + def _finalize(self) -> None: + self._analyzer = None + self._finalized = True + + def convert(self) -> ir.Function: + """Convert the Python AST to an ONNX IR function.""" + if self._finalized: + return self._current_fn + + func_def = self._ast_root + self._analyzer = analysis.AstAnalyzer(func_def, self._message, self.globals) + self._translate_function_def(func_def) + self._finalize() return self._current_fn + # TODO(justinchuby): Handle function registration to the opset + # self._opset.add_function_def(fn_ir) + + +def _is_constant_expr(node: ast.AST) -> bool: + """Check if the AST node is a constant expression.""" + if isinstance(node, ast.UnaryOp): + return _is_constant_expr(node.operand) + if isinstance( + node, + ( + ast.Call, + ast.BinOp, + ast.UnaryOp, + ast.Compare, + ast.Attribute, + ast.List, + ast.Load, + ast.Constant, + ), + ): + return all(_is_constant_expr(c) for c in ast.iter_child_nodes(node)) + return False + - def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: - if isinstance(stmt, ast.FunctionDef): - self._init_function_translation() - if self.default_opset_ is None: - opset = self._find_onnx_opset(stmt) - if opset: - self._set_default_opset(opset, stmt) - domain = self.this_module.domain - self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) - self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals) - fn_ir = self._translate_function_def_common(stmt) - fn_ir.debug_print() - self.this_module.add_function_def(fn_ir) - self._analyzer = None - return fn_ir - raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.") - - def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: - """Translate a (top-level) function signature.""" - domain = self.this_module.domain - self._current_fn = self.ir_builder.new_function(fn.name, domain, True) - return self._translate_function_signature_common(fn) +def _separate_inputs_and_attrs( + signature: _schemas.OpSignature, + args: Sequence[ast.expr], + kwargs: Mapping[str, ast.expr], +) -> tuple[Sequence[ast.expr], dict[str, ast.expr]]: + """Construct two mappings: name to inputs and named to attributes based on the signature and args/kwargs. + + This function uses the OpSignature to determine which argument in args and kwargs corresponds to + which parameter in the signature. ONNX node inputs are stored in named_inputs, and attributes are + stored in named_attrs. If an _optional input_ is not provided, it is filled with None. + + Args: + signature: The OpSignature for the node. + args: The positional arguments for the node. + kwargs: The keyword arguments for the node. + + Returns: + A tuple of two mappings: named_inputs and named_attrs. + + Raises: + ValueError: If a required parameter is not provided. + """ + # 1. Construct inputs, attrs based on (args, kwargs) and the signature. + # a. Loop over all parameters in the signature and args together + # b. Depending on param.is_input, Record inputs or named_attrs[param.name] = arg + # c. Handle kwargs as well + inputs_reversed: Sequence[Any] = [] + named_attrs: dict[str, Any] = {} + reversed_args_stack = list(reversed(args)) + for param in signature.params: + if isinstance(param, _schemas.Parameter): + # Handle inputs + if reversed_args_stack: + # First exhaust the positional arguments + if param.variadic: + # Handle variadic arguments + inputs_reversed = [*reversed(args)] + reversed_args_stack.clear() + else: + inputs_reversed.append(reversed_args_stack.pop()) + elif param.name in kwargs: + inputs_reversed.append(kwargs[param.name]) + elif param.required: + raise ValueError( + f"Required parameter '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional parameter '%s' is not provided. Added as None. Signature: %s", + param.name, + signature, + ) + inputs_reversed.append(None) + else: + # Handle attributes + attribute: ir.Attr | None + assert isinstance(param, _schemas.AttributeParameter), ( + f"Expected AttributeParameter, got {type(param)}" + ) + if reversed_args_stack: + # First exhaust the positional arguments + attribute = reversed_args_stack.pop() # type: ignore[assignment] + elif kwargs.get(param.name) is not None: + attribute = kwargs[param.name] # type: ignore[assignment] + else: + if param.required: + raise ValueError( + f"Required attribute '{param.name}' is not provided. " + f"Signature: {signature}. Args: {args}. Kwargs: {kwargs}." + ) + else: + logger.debug( + "Optional attribute '%s' is None. Dropped. Signature: %s", + param.name, + signature, + ) + continue + named_attrs[param.name] = attribute + return tuple(reversed(inputs_reversed)), named_attrs + + +def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -> ir.Attr: + """Convert an attribute reference to an ONNX ref attribute.""" + + # TODO(justinchuby): Consider using a convenience function + pytype = val.typeinfo + attrtype = _schemas.get_attr_type(pytype) + attrname = None + if attrtype is ir.AttributeType.FLOAT: + attrname = "value_float" + elif attrtype is ir.AttributeType.INT: + attrname = "value_int" + elif attrtype is ir.AttributeType.STRING: + attrname = "value_string" + elif attrtype is ir.AttributeType.INTS: + attrname = "value_ints" + else: + msg = f"Unsupported attribute type {pytype!r}." + fail(info.msg(msg) if info else msg) + # TODO(justinchuby): What is the ref attr name? + return ir.RefAttr(attrname, val.value, attrtype) + + +def _find_onnx_opset(node: ast.AST, globals: dict[str, Any]) -> values.Opset | None: + """Find the (first) ONNX opset used in the function, if any.""" + # Search for a Call expression of form "op.OpName(...)" + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Attribute): + opset_expr = node.func.value + if isinstance(opset_expr, ast.Name): + if opset_expr.id in globals: + opset = globals[opset_expr.id] + if isinstance(opset, values.Opset) and opset.domain == "": + return opset + for child in ast.iter_child_nodes(node): + res = _find_onnx_opset(child, globals) + if res is not None: + return res + return None