Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
69b08fb
Rename converter
justinchuby Jul 15, 2025
7e0a767
Update fields
justinchuby Jul 15, 2025
8c7abc6
wip
justinchuby Jul 15, 2025
370a9f5
make locals a stack
justinchuby Jul 15, 2025
10473a9
wip
justinchuby Jul 15, 2025
3889e70
wip
justinchuby Jul 16, 2025
e09fdcc
update
justinchuby Jul 16, 2025
fa60de6
wip
justinchuby Jul 16, 2025
3aff171
wip
justinchuby Jul 16, 2025
761451a
wip
justinchuby Jul 18, 2025
882af66
wip
justinchuby Jul 18, 2025
852cc42
update
justinchuby Jul 18, 2025
be610e9
continue
justinchuby Jul 18, 2025
98754ab
Update analysis pass
justinchuby Jul 19, 2025
22eeddc
live_out
justinchuby Jul 19, 2025
c74b854
Update if
justinchuby Jul 19, 2025
2c43e9a
fmt
justinchuby Jul 19, 2025
c53b8f3
Update emit ordering
justinchuby Jul 22, 2025
c84ad91
update emit call
justinchuby Jul 22, 2025
0af0707
attrs
justinchuby Jul 22, 2025
2196f99
return values
justinchuby Jul 22, 2025
b56a161
wip
justinchuby Jul 24, 2025
359eb0b
Merge branch 'main' into justinchu/migrate-irbuilder
justinchuby Jul 25, 2025
b04a455
wip
justinchuby Jul 25, 2025
e4c9576
wip
justinchuby Jul 25, 2025
11a735e
refactor
justinchuby Jul 29, 2025
0cd1f20
_ValueEnvironment
justinchuby Jul 30, 2025
d98e8b5
emit_const
justinchuby Jul 30, 2025
38b1c03
Merge branch 'main' into justinchu/migrate-irbuilder
justinchuby Aug 1, 2025
f1c10c6
_ValueEnvironment
justinchuby Aug 16, 2025
ce2428f
fixme
justinchuby Aug 16, 2025
7f2fd2b
Merge branch 'main' into justinchu/migrate-irbuilder
justinchuby Oct 27, 2025
d95fa48
Merge branch 'main' into justinchu/migrate-irbuilder
justinchuby Oct 29, 2025
bba0ff6
Merge
justinchuby Oct 29, 2025
7a0b016
WIP
justinchuby Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
917 changes: 509 additions & 408 deletions onnxscript/converter.py → onnxscript/_converter.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import onnxscript
import onnxscript.testing
from onnxscript import BOOL, FLOAT, INT64, converter, graph, script, tensor
from onnxscript import BOOL, FLOAT, INT64, _converter, graph, script, tensor
from onnxscript.onnx_opset import opset11 as op11
from onnxscript.onnx_opset import opset15 as op
from tests.common import onnx_script_test_case, testutils
Expand Down Expand Up @@ -437,12 +437,12 @@
global_names = globals().copy()
top_level_ast = ast.parse(source)
f_ast = top_level_ast.body[0]
cvt = converter.Converter(
cvt = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=op, global_names=global_names, source=source, default_opset=op
)
try:
cvt.translate_function_def(f_ast)
except converter.TranslationError as e:
except _converter.TranslationError as e:
if msg not in str(e):
raise AssertionError(f"Unable to find {msg!r} in {e!r} in\n{source}") from e
return
Expand Down
7 changes: 4 additions & 3 deletions onnxscript/_internal/analysis.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Analysis utilities for Python AST."""
from __future__ import annotations

import ast
from typing import Any, Optional, Sequence, Set
from typing import Any, Optional, Sequence

from onnxscript import sourceinfo
from onnxscript._internal import ast_utils
Expand All @@ -15,7 +16,7 @@ def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:
return for_stmt.target.id


def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
def _used_vars(expr: Optional[ast.expr]) -> set[str]:
"""Return set of all variables used, including function names, in an expression."""
if expr is None:
return set()
Expand All @@ -35,7 +36,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
return result


def _lhs_vars(lhs: ast.expr) -> Set[str]:
def _lhs_vars(lhs: ast.expr) -> set[str]:
"""Return set of assigned variables in the lhs of an assignment statement."""

def get_id(e):
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from onnxscript import ir, tensor

if TYPE_CHECKING:
from onnxscript import converter
from onnxscript import _converter

# Conversions from python values to ONNX are used by both the script converter as well
# as the eager-mode runtime and both need to be consistent. The script converter converts
Expand Down Expand Up @@ -187,24 +187,24 @@


def static_cast_inputs(
converter_: converter.Converter,
converter_: _converter.Converter,
op_schema: Optional[OpSchema],
args: Sequence[Optional[converter.Variable]],
args: Sequence[Optional[_converter.Variable]],

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
) -> tuple[str, ...]:
"""Used for autocast during script-translation.
This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))"
Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed.
"""

def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]:
def get_type_info(x: Optional[_converter.Variable]) -> Optional[_converter.Variable]:

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
"""Returns x back if x can serve as the target-type for a cast (as the second
argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is
castable, while X can serve as the target-type.
"""
return None if x is None or x.is_castable else x

def cast_like(
x: Optional[converter.Variable], y: Optional[converter.Variable]
x: Optional[_converter.Variable], y: Optional[_converter.Variable]

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
) -> Optional[str]:
if x is None:
return None
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/ir/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _is_optional(type_: type) -> bool:
return False


def _get_attr_type(type_: type) -> ir.AttributeType:
def get_attr_type(type_: type) -> ir.AttributeType:
"""Obtain the type of the attribute from a Python class."""
try:
if type_ in _PY_TYPE_TO_ATTR_TYPE:
Expand Down Expand Up @@ -455,7 +455,7 @@ def from_function(
)
else:
type_ = type_hints[param.name]
if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
# Construct the default attribute
if param.default is not inspect.Parameter.empty:
# TODO: Use ir_convenience instead to handle int as float
Expand Down
13 changes: 13 additions & 0 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ def attr_proto(self) -> onnx.AttributeProto:


class IRStmt:
"""An IR statement (representing an operation).

Details:
- `result`: A sequence of variable names that this statement assigns to.
- `callee`: The operation being called, represented as an instance of `values.Op`.
- `args`: A sequence of arguments to the operation, which can be variable names or
`None` for optional arguments.
- `attrs`: A sequence of attributes for the operation, represented as `IRAttributeValue`
instances.
- `sub_functions`: A dictionary of sub-functions that this statement may call, mapping
function names to `onnx.FunctionProto` instances.
"""

def __init__(
self,
result: Sequence[str],
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import ParamSpec

import onnxscript
from onnxscript import converter, ir, irbuilder, values
from onnxscript import _converter, ir, irbuilder, values
from onnxscript._internal import ast_utils

_R = TypeVar("_R")
Expand All @@ -29,7 +29,7 @@
# See if conversion succeeds.
# TODO: cleanup Converter interface/API, separating checker from
# converter
convert = converter.Converter(
convert = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=opset,
global_names=global_names,
source=source,
Expand Down
51 changes: 16 additions & 35 deletions onnxscript/type_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional, Sequence, Union

import onnx
import onnx_ir as ir

from onnxscript import onnx_types

Expand All @@ -24,35 +25,35 @@

# Map from python type to corresponding ONNX AttributeProto type
_PYTYPE_TO_ATTRTYPE_MAP = {
float: onnx.AttributeProto.FLOAT,
int: onnx.AttributeProto.INT,
str: onnx.AttributeProto.STRING,
bool: onnx.AttributeProto.INT, # experimental
float: ir.AttributeType.FLOAT,
int: ir.AttributeType.INT,
str: ir.AttributeType.STRING,
bool: ir.AttributeType.INT, # experimental
}

# Map from python type to corresponding ONNX AttributeProto type,
# for repeated (i.e., list of) values
_LISTTYPE_TO_ATTRTYPE_MAP = {
float: onnx.AttributeProto.FLOATS,
int: onnx.AttributeProto.INTS,
str: onnx.AttributeProto.STRINGS,
bool: onnx.AttributeProto.INTS, # experimental
float: ir.AttributeType.FLOATS,
int: ir.AttributeType.INTS,
str: ir.AttributeType.STRINGS,
bool: ir.AttributeType.INTS, # experimental
}

_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])

# Map from ONNX AttributeProto type to its representation (in ONNX Script).
_ATTRTYPE_TO_REPR = {
onnx.AttributeProto.FLOAT: "float",
onnx.AttributeProto.INT: "int",
onnx.AttributeProto.STRING: "str",
onnx.AttributeProto.FLOATS: "Sequence[float]",
onnx.AttributeProto.INTS: "Sequence[int]",
onnx.AttributeProto.STRINGS: "Sequence[str]",
ir.AttributeType.FLOAT: "float",
ir.AttributeType.INT: "int",
ir.AttributeType.STRING: "str",
ir.AttributeType.FLOATS: "Sequence[float]",
ir.AttributeType.INTS: "Sequence[int]",
ir.AttributeType.STRINGS: "Sequence[str]",
}


def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType) -> str:
def onnx_attr_type_to_onnxscript_repr(attr_type: ir.AttributeType) -> str:
if attr_type not in _ATTRTYPE_TO_REPR:
supported = ", ".join(
f"'{onnx.AttributeProto.AttributeType.Name(v)}'" for v in _ATTRTYPE_TO_REPR
Expand Down Expand Up @@ -95,26 +96,6 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP


def pytype_to_attrtype(
pytype: TypeAnnotationValue,
) -> Optional[onnx.AttributeProto.AttributeType]:
pytype = _remove_annotation(pytype)
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
type_constructor = typing.get_origin(pytype)
# Remove Optional wrapper if present, which is represented as an Union[..., type(None)]
if type_constructor is typing.Union:
# Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
args = [x for x in typing.get_args(pytype) if x is not type(None)]
if len(args) == 1:
return pytype_to_attrtype(args[0])
if type_constructor in _LIST_CONSTRUCTORS:
elt_type = typing.get_args(pytype)[0]
if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
return None


def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
"""Returns True if base type of pytype is bool, False otherwise."""
pytype = _remove_annotation(pytype)
Expand Down
94 changes: 2 additions & 92 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import logging
import types
import typing
from enum import IntFlag
from typing import ( # type: ignore[attr-defined]
Any,
Callable,
Expand All @@ -18,15 +17,13 @@
Protocol,
Sequence,
TypeVar,
_GenericAlias,
)

import onnx
import onnx.defs
from typing_extensions import ParamSpec

from onnxscript import converter as converter_module
from onnxscript import irbuilder, sourceinfo, type_annotation
from onnxscript import _converter, irbuilder, type_annotation
from onnxscript._internal import ast_utils, deprecation
from onnxscript.ir import _schemas

Expand Down Expand Up @@ -638,7 +635,7 @@
closure = inspect.getclosurevars(self.func)
global_names = module.__dict__.copy()
global_names.update(closure.nonlocals)
converter = converter_module.Converter(
converter = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=self._opset,
global_names=global_names,
source=src,
Expand Down Expand Up @@ -686,90 +683,3 @@
# argument order from the Python function definition, which is lost in OpSchema.
self._param_schemas = _param_schemas_from_function_ir(self.function_ir)
return self._param_schemas


class SymbolValue:
"""Represents script-time value information about named variables used in a script.
At translation-time, the (local) variables of a script, including its parameters,
are bound to a SymbolValue.
SymbolValues fall into the following categories:
AttrRef: Function parameters of attribute-kind, also mapped to ONNX attributes
Dynamic: values computed at runtime (of tensor type, for now) mapped to NodeArgs.
Dynamic values include input-parameters of the script, as well intermediate
values computed in the script.
For example, consider the following script definition:
::
@script()
def ThresholdedRelu(X, alpha: float):
zero = op.CastLike(0, X)
return op.Where(X > alpha, X, zero)
Here, `X` has a Dynamic value, `alpha` has an AttrRef value, and `zero`
has a Dynamic value.
Scripts may also contain references to global variables, but the translator
does not associate a SymbolValue with them. The python value of global variables
is used directly in the translation, and such global variables are intended
to be used for limited purposes, namely:
* To identify an opset
* To represent constant-values, translated into ONNX constants.
"""

def __init__(self, info: sourceinfo.SourceInfo) -> None:
if not isinstance(info, sourceinfo.SourceInfo):
raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.")
self.info = info


class AttrRef(SymbolValue):
def __init__(
self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo
) -> None:
"""Initializes AttrRef.
Arguments:
attr_name: name of the attribute-parameter
typeinfo: type annotation of the attribute.
op's attributes in ONNX are usually single type or list of single type.
info: for debugging use.
"""
super().__init__(info)
self.value = attr_name
self.typeinfo = typeinfo
if not isinstance(typeinfo, (type, _GenericAlias)):
# typing._GenericAlias for List[int] and List[str], etc.
raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
self.typeinfo = typeinfo


class DynamicKind(IntFlag):
Unknown = 0
Input = 1
Output = 2
Intermediate = 4
Loop = 8


class Dynamic(SymbolValue):
def __init__(
self, onnx_var: str, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None
) -> None:
"""Initializes Dynamic.
Arguments:
onnx_var: the name of the ONNX variable used to represent this value
kind: the DynamicKind of this variable
info: source-location information for error-messages/debugging
typeinfo: type-information for the value
"""
super().__init__(info)
assert isinstance(kind, DynamicKind)
self.value = onnx_var
self.kind = kind
self.typeinfo = typeinfo
Loading