Skip to content
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
596 changes: 341 additions & 255 deletions onnxscript/converter.py → onnxscript/_converter.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import onnxscript
import onnxscript.testing
from onnxscript import BOOL, FLOAT, INT64, converter, graph, script, tensor
from onnxscript import BOOL, FLOAT, INT64, _converter, graph, script, tensor
from onnxscript.onnx_opset import opset11 as op11
from onnxscript.onnx_opset import opset15 as op
from tests.common import onnx_script_test_case, testutils
Expand Down Expand Up @@ -437,12 +437,12 @@
global_names = globals().copy()
top_level_ast = ast.parse(source)
f_ast = top_level_ast.body[0]
cvt = converter.Converter(
cvt = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=op, global_names=global_names, source=source, default_opset=op
)
try:
cvt.translate_function_def(f_ast)
except converter.TranslationError as e:
except _converter.TranslationError as e:
if msg not in str(e):
raise AssertionError(f"Unable to find {msg!r} in {e!r} in\n{source}") from e
return
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from onnxscript import ir, tensor

if TYPE_CHECKING:
from onnxscript import converter
from onnxscript import _converter

# Conversions from python values to ONNX are used by both the script converter as well
# as the eager-mode runtime and both need to be consistent. The script converter converts
Expand Down Expand Up @@ -187,24 +187,24 @@ def get_type_info(x):


def static_cast_inputs(
converter_: converter.Converter,
converter_: _converter.Converter,
op_schema: Optional[OpSchema],
args: Sequence[Optional[converter.Variable]],
args: Sequence[Optional[_converter.Variable]],

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
) -> tuple[str, ...]:
"""Used for autocast during script-translation.
This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))"
Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed.
"""

def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]:
def get_type_info(x: Optional[_converter.Variable]) -> Optional[_converter.Variable]:

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
"""Returns x back if x can serve as the target-type for a cast (as the second
argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is
castable, while X can serve as the target-type.
"""
return None if x is None or x.is_castable else x

def cast_like(
x: Optional[converter.Variable], y: Optional[converter.Variable]
x: Optional[_converter.Variable], y: Optional[_converter.Variable]

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_converter.Variable" is not defined To disable, use # type: ignore[name-defined]
) -> Optional[str]:
if x is None:
return None
Expand Down
5 changes: 3 additions & 2 deletions onnxscript/ir/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import collections.abc
import copy

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused import copy (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

copy imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
import dataclasses
import inspect
import logging
Expand Down Expand Up @@ -210,7 +211,7 @@
return False


def _get_attr_type(type_: type) -> ir.AttributeType:
def get_attr_type(type_: type) -> ir.AttributeType:
"""Obtain the type of the attribute from a Python class."""
try:
if type_ in _PY_TYPE_TO_ATTR_TYPE:
Expand Down Expand Up @@ -455,7 +456,7 @@
)
else:
type_ = type_hints[param.name]
if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
if (attr_type := get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
# Construct the default attribute
if param.default is not inspect.Parameter.empty:
# TODO: Use ir_convenience instead to handle int as float
Expand Down
12 changes: 12 additions & 0 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,18 @@ def attr_proto(self) -> onnx.AttributeProto:


class IRStmt:
"""An IR statement (representing an operation).

Details:
- `result`: A sequence of variable names that this statement assigns to.
- `callee`: The operation being called, represented as an instance of `values.Op`.
- `args`: A sequence of arguments to the operation, which can be variable names or
`None` for optional arguments.
- `attrs`: A sequence of attributes for the operation, represented as `IRAttributeValue`
instances.
- `sub_functions`: A dictionary of sub-functions that this statement may call, mapping
function names to `onnx.FunctionProto` instances.
"""
def __init__(
self,
result: Sequence[str],
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import ParamSpec

import onnxscript
from onnxscript import converter, ir, irbuilder, values
from onnxscript import _converter, ir, irbuilder, values
from onnxscript._internal import ast_utils

_R = TypeVar("_R")
Expand All @@ -29,7 +29,7 @@
# See if conversion succeeds.
# TODO: cleanup Converter interface/API, separating checker from
# converter
convert = converter.Converter(
convert = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=opset,
global_names=global_names,
source=source,
Expand Down
51 changes: 16 additions & 35 deletions onnxscript/type_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional, Sequence, Union

import onnx
import onnx_ir as ir

from onnxscript import onnx_types
from onnxscript._internal import version_utils
Expand All @@ -25,35 +26,35 @@

# Map from python type to corresponding ONNX AttributeProto type
_PYTYPE_TO_ATTRTYPE_MAP = {
float: onnx.AttributeProto.FLOAT,
int: onnx.AttributeProto.INT,
str: onnx.AttributeProto.STRING,
bool: onnx.AttributeProto.INT, # experimental
float: ir.AttributeType.FLOAT,
int: ir.AttributeType.INT,
str: ir.AttributeType.STRING,
bool: ir.AttributeType.INT, # experimental
}

# Map from python type to corresponding ONNX AttributeProto type,
# for repeated (i.e., list of) values
_LISTTYPE_TO_ATTRTYPE_MAP = {
float: onnx.AttributeProto.FLOATS,
int: onnx.AttributeProto.INTS,
str: onnx.AttributeProto.STRINGS,
bool: onnx.AttributeProto.INTS, # experimental
float: ir.AttributeType.FLOATS,
int: ir.AttributeType.INTS,
str: ir.AttributeType.STRINGS,
bool: ir.AttributeType.INTS, # experimental
}

_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])

# Map from ONNX AttributeProto type to its representation (in ONNX Script).
_ATTRTYPE_TO_REPR = {
onnx.AttributeProto.FLOAT: "float",
onnx.AttributeProto.INT: "int",
onnx.AttributeProto.STRING: "str",
onnx.AttributeProto.FLOATS: "Sequence[float]",
onnx.AttributeProto.INTS: "Sequence[int]",
onnx.AttributeProto.STRINGS: "Sequence[str]",
ir.AttributeType.FLOAT: "float",
ir.AttributeType.INT: "int",
ir.AttributeType.STRING: "str",
ir.AttributeType.FLOATS: "Sequence[float]",
ir.AttributeType.INTS: "Sequence[int]",
ir.AttributeType.STRINGS: "Sequence[str]",
}


def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType) -> str:
def onnx_attr_type_to_onnxscript_repr(attr_type: ir.AttributeType) -> str:
if attr_type not in _ATTRTYPE_TO_REPR:
supported = ", ".join(
f"'{onnx.AttributeProto.AttributeType.Name(v)}'" for v in _ATTRTYPE_TO_REPR
Expand Down Expand Up @@ -87,26 +88,6 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP


def pytype_to_attrtype(
pytype: TypeAnnotationValue,
) -> Optional[onnx.AttributeProto.AttributeType]:
pytype = _remove_annotation(pytype)
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
type_constructor = typing.get_origin(pytype)
# Remove Optional wrapper if present, which is represented as an Union[..., type(None)]
if type_constructor is typing.Union:
# Filter out type(None), since typing.Optional[X] evaluates to Union[X, type(None)]
args = [x for x in typing.get_args(pytype) if x is not type(None)]
if len(args) == 1:
return pytype_to_attrtype(args[0])
if type_constructor in _LIST_CONSTRUCTORS:
elt_type = typing.get_args(pytype)[0]
if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
return None


def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
"""Returns True if base type of pytype is bool, False otherwise."""
pytype = _remove_annotation(pytype)
Expand Down
8 changes: 3 additions & 5 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
import onnx.defs
from typing_extensions import ParamSpec

from onnxscript import converter as converter_module
from onnxscript import irbuilder, sourceinfo, type_annotation
from onnxscript import _converter, irbuilder, sourceinfo, type_annotation

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused sourceinfo imported from onnxscript (unused-import)
See unused-import. To disable, use # pylint: disable=unused-import

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

onnxscript.sourceinfo imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
from onnxscript._internal import ast_utils, deprecation
from onnxscript.ir import _schemas

Expand Down Expand Up @@ -638,7 +637,7 @@
closure = inspect.getclosurevars(self.func)
global_names = module.__dict__.copy()
global_names.update(closure.nonlocals)
converter = converter_module.Converter(
converter = _converter.Converter(

Check failure

Code scanning / lintrunner

PYLINT/E1120 Error

No value for argument 'root' in constructor call (no-value-for-parameter)
See no-value-for-parameter. To disable, use # pylint: disable=no-value-for-parameter
opset=self._opset,
global_names=global_names,
source=src,
Expand Down Expand Up @@ -741,13 +740,12 @@
"""
super().__init__(info)
self.value = attr_name
self.typeinfo = typeinfo

if not isinstance(typeinfo, (type, _GenericAlias)):
# typing._GenericAlias for List[int] and List[str], etc.
raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
self.typeinfo = typeinfo


class DynamicKind(IntFlag):
Unknown = 0
Input = 1
Expand Down
Loading