@@ -21,7 +21,7 @@
 
 def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]:
     for op in registration.default_registry.values():
-        for func in (*op.overloads, *op.privates, *op.complex):
+        for func in (*op.overloads, *op.complex):
             if isinstance(func, onnxscript.OnnxFunction):
                 yield func
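Since `op.privates` no longer exists, every consumer that unpacked it has to drop it, as this hunk does. A minimal sketch of how a downstream consumer might enumerate the public TorchLib functions after this change — the helper name is made up for illustration, and it assumes `OnnxFunction.name` holds the function's name:

    import onnxscript
    from onnxscript.function_libs.torch_lib import registration

    def public_torchlib_function_names() -> set[str]:
        """Collect names of all registered OnnxFunctions (privates are gone)."""
        names: set[str] = set()
        for op in registration.default_registry.values():
            # Only real and complex overloads remain after this PR.
            for func in (*op.overloads, *op.complex):
                if isinstance(func, onnxscript.OnnxFunction):
                    names.add(func.name)
        return names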
26 changes: 18 additions & 8 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3704,7 +3704,6 @@
     padding_mode_options = ("zeros", "border", "reflection")
     padding_mode_str = padding_mode_options[padding_mode]
 
-    # Only one onnx Op so don't put into private function
     return op.GridSample(
         input,
         grid,
@@ -3730,7 +3729,6 @@
     padding_mode_options = ("zeros", "border", "reflection")
     padding_mode_str = padding_mode_options[padding_mode]
 
-    # Only one onnx Op so don't put into private function
     return op.GridSample(
         input,
         grid,
@@ -4060,7 +4058,9 @@
 
 
 @torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
-def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
+def aten_index(
+    self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]]
+) -> TensorType:
     """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
 
     NOTE: Understanding `aten::index`
@@ -4080,14 +4080,19 @@
     None in `indices` are like fillers for dimensions that cannot be removed in the process.
     """
+    # Handle Boolean indexing first
+    for index in indices:
+        if index is None:
+            continue
+        if index.dtype == BOOL.dtype:
+            return _aten_index_bool(self, indices)
+
     index_ranks = [len(index.shape) for index in indices if index is not None]
 
     return _aten_index_onnx(self, indices, index_ranks)
 
 
-@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
-def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:  # pylint: disable=inconsistent-return-statements
+def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:  # pylint: disable=inconsistent-return-statements
     index_ranks = [len(index.shape) for index in indices if index is not None]
 
     if index_ranks[0] == 1:
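For context on the dispatch this hunk introduces: in eager PyTorch, indexing with a boolean mask is equivalent to advanced indexing with the coordinates of the mask's True entries, which is why a dtype check alone is enough to route to the boolean helper. A small eager-mode illustration (plain PyTorch, not code from this PR):

    import torch

    x = torch.arange(12).reshape(3, 4)
    mask = torch.tensor([True, False, True])

    # Boolean indexing selects rows 0 and 2; it is equivalent to
    # advanced indexing with the nonzero coordinates of the mask.
    assert torch.equal(x[mask], x[mask.nonzero(as_tuple=True)])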
@@ -4146,7 +4151,7 @@
 @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
 def aten_index_put(
     self: TReal,
-    indices: Sequence[INT64],
+    indices: Sequence[Optional[Union[INT64, BOOL]]],
     values: TReal,
     accumulate: bool = False,
 ) -> TReal:
@@ -4155,6 +4160,12 @@
     See implementation of `torch.onnx.symbolic_opset11.index_put
     <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
     """
+    # Handle Boolean indexing first
+    for index in indices:
+        if index is None:
+            continue
+        if index.dtype == BOOL.dtype:
+            return _aten_index_put_bool(self, indices, values, accumulate=accumulate)
 
     def _make_reshape_list_broadcastable(reshape_list, values_shape):
         # Remove ones until the rank of reshape_list matches values_shape.
@@ -4232,8 +4243,7 @@
     return result
 
 
-@torch_op("aten::index_put", trace_only=True)
-def aten_index_put_bool(
+def _aten_index_put_bool(
     self: TReal,
     indices: Sequence[BOOL],
     values: TReal,
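The boolean `index_put` path follows the same pattern: the public entry point dispatches on dtype, the private helper does the work. The eager semantics being matched, as a rough sketch (illustrative only, not code from this PR):

    import torch

    x = torch.zeros(2, 3)
    mask = torch.tensor([[True, False, True], [False, True, False]])
    value = torch.tensor(1.0)

    # index_put with a boolean index is masked assignment;
    # accumulate=True would make it masked addition instead.
    out = torch.ops.aten.index_put(x, [mask], value)
    assert torch.equal(out, torch.where(mask, value, x))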
1 change: 0 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
@@ -330,7 +330,6 @@ def aten_col2im(
     else:  # assert len(padding) == 4, already [w, x, y, z]
         pads = padding
 
-    # Only one ONNX op here so didn't write a private function
     return op.Col2Im(
         self,
         output_size,
29 changes: 17 additions & 12 deletions onnxscript/function_libs/torch_lib/registration.py
@@ -22,14 +22,12 @@ class OverloadedFunction:
     Attributes:
         name: Name of the op. E.g. "aten::add".
         overloads: Overloads function.
-        privates: Private functions not exposed to users.
         complex: Support complex functions.
     """
 
     def __init__(self, name: str):
         self.name = name
         self.overloads: list[Any] = []
-        self.privates: list[Any] = []
         self.complex: list[Any] = []

Expand All @@ -39,17 +37,22 @@ class Registry:
def __init__(self):
self._registry: dict[str, OverloadedFunction] = {}

def register(
self, func: Any, name: str, *, private: bool = False, complex: bool = False
) -> None:
def register(self, func: Any, name: str, *, complex: bool = False) -> None:
"""Register a function."""

if private:
self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func)
elif complex:
self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func)
overloaded_function = self._registry.setdefault(name, OverloadedFunction(name))

if complex:
if overloaded_function.complex:
raise ValueError(
f"Complex overload for '{name}' already registered: {overloaded_function.complex}."
)
overloaded_function.complex.append(func)
else:
self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func)
if overloaded_function.overloads:
raise ValueError(
f"Real overload for '{name}' already registered: {overloaded_function.overloads}."
)
overloaded_function.overloads.append(func)

def __getitem__(self, name):
return self._registry[name]
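The rewrite also tightens semantics: registering a second real (or complex) overload under the same name now fails loudly instead of silently growing the list. A quick sketch of the guard in action, assuming the `Registry` class above (the op name is invented):

    registry = Registry()

    def impl_a(x): ...
    def impl_b(x): ...

    registry.register(impl_a, "aten::example_op")
    registry.register(impl_b, "aten::example_op")
    # ValueError: Real overload for 'aten::example_op' already registered: [...]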
@@ -131,7 +134,9 @@ def wrapper(
 
         assert registry is not None
         for name_ in _check_and_normalize_names(name):
-            registry.register(processed_func, name_, private=private, complex=complex)
+            if private:
+                continue
+            registry.register(processed_func, name_, complex=complex)
         return processed_func
 
     return wrapper
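Net effect in the decorator: `private=True` now means "do not register at all" rather than "register in a side bucket", so private helpers are invisible to the registry. A hedged sketch of the two cases, assuming `torch_op` keeps its `private` keyword as the hunk above suggests (op name and helper are illustrative):

    # Registered: appears in registration.default_registry["aten::example_op"].overloads.
    @torch_op("aten::example_op", trace_only=True)
    def aten_example_op(self):
        ...

    # Skipped: with private=True the wrapper never calls registry.register,
    # so this helper stays an ordinary function.
    @torch_op("aten::example_op", private=True, trace_only=True)
    def _aten_example_op_helper(self):
        ...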
18 changes: 2 additions & 16 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -721,23 +721,10 @@ def _where_input_wrangler(
     # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size),  # no test case in OPS_DB
     # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero),  # no test case in OPS_DB
     TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index),
-    TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool),
-    TorchLibOpInfo(
-        "index_put_bool",
-        core_ops.aten_index_put_bool,
-        input_wrangler=_index_put_input_wrangler,
-    ).skip(
-        matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
-        reason="this Aten overload only supports tensor(bool) as indices",
-    ),
+    TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index),
     TorchLibOpInfo(
         "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler
-    )
-    .skip(
-        matcher=lambda sample: sample.args[0][0].dtype != torch.int64,
-        reason="this Aten overload only supports tensor(int) as indices",
-    )
-    .xfail(
+    ).skip(
         dtypes=(torch.float16,),
         matcher=lambda sample: sample.kwargs.get("accumulate") is True,
         reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'",
@@ -1806,7 +1793,6 @@
 ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
 ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
 ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",))
-ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
 ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))