Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 15 additions & 35 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4261,7 +4261,7 @@ def aten_index_copy(
raise NotImplementedError()


@torch_op(("aten::index_put", "aten::_unsafe_index_put"))
@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
def aten_index_put(
self: TReal,
indices: Sequence[INT64],
Expand All @@ -4275,18 +4275,18 @@ def aten_index_put(
"""

# TODO(justinchuby): Handle when indicies has more than one element
index = op.SequenceAt(indices, 0)
index = indices[0]
new_index = op.Unsqueeze(index, [-1])

if op.Cast(accumulate, to=BOOL.dtype):
if accumulate:
result = op.ScatterND(self, new_index, values, reduction="add")
else:
result = op.ScatterND(self, new_index, values)

return result


@torch_op("aten::index_put")
@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
Expand All @@ -4295,37 +4295,17 @@ def aten_index_put_bool(
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

index = op.SequenceAt(indices, 0) # assume indices only have 1 element
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
index_int = op.Cast(index, to=INT32.dtype)
# if all False, return op.Identity(self)
if op.ReduceSum(index_int) == 0:
result = self
else:
# change array([F,F,T,F,F]) to array([2])
index = op.ArgMax(index_int) # assume index only have 1 True
# change array([2]) to array([2,2,2,2,2])
self_dim_1 = op.Shape(self, start=1, end=2)
index_dim_0 = op.Shape(index, start=0, end=1)
shape = op.Concat(self_dim_1, index_dim_0, axis=0)
new_ind = op.Expand(index, shape)
new_ind_t = op.Transpose(new_ind)

# values must have same rank with input(self)
if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator]
values = op.Unsqueeze(values, op.Constant(value_ints=[0]))

if op.Cast(accumulate, to=BOOL.dtype):
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
zeros = op.CastLike(zeros, values)
result = op.ScatterElements(zeros, new_ind_t, values)
# FIXME: type promotion
result = op.CastLike(result, self)
result = op.Add(result, self)
else:
result = op.ScatterElements(self, new_ind_t, values)

return result
# TODO: Support indices with more than 1 elements
index = indices[0]
# accumulate should be always False, True does not make sense but an assert would be great
Comment thread
xadupre marked this conversation as resolved.
# Reshape indices so it can be properly broadcasted
lself, lindex = len(self.shape), len(index.shape)
if lself > lindex:
shape = op.Shape(index)
append = op.Constant(value_ints=[1 for _ in range(lself - lindex)])
new_shape = op.Concat(shape, append, axis=0)
index = op.Reshape(index, new_shape)
return op.Where(index, values, self)
Comment thread
justinchuby marked this conversation as resolved.
Comment thread
justinchuby marked this conversation as resolved.
Comment thread
justinchuby marked this conversation as resolved.
Outdated


def aten_index_reduce(
Expand Down
83 changes: 83 additions & 0 deletions tests/function_libs/torch_lib/aten_ops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import onnxruntime
import torch


class TestOnnxExportAten(unittest.TestCase):
def test_aten_index_put_mask_bool_fixed_broadcast_2d(self):
class Model(torch.nn.Module):
def forward(self, x, values):
x = x.clone()
mask = torch.tensor([True, False, True, True, False]).to(torch.bool)
Comment thread
justinchuby marked this conversation as resolved.
Outdated
x[mask] = values
return x

model = Model()
xs = (
torch.arange(25).reshape((5, 5)).to(torch.float32),
torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32),
)
expected = model(*xs)
ep = torch.onnx.export(model, xs, dynamo=True)
sess = onnxruntime.InferenceSession(
ep.model_proto.SerializeToString(),
providers=["CPUExecutionProvider"],
)
feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs]))
got = sess.run(None, feeds)[0]
torch.testing.assert_close(expected, torch.from_numpy(got))

def test_aten_index_put_mask_bool_fixed_broadcast_3d(self):
class Model(torch.nn.Module):
def forward(self, x, values):
x = x.clone()
mask = torch.tensor([True, False]).to(torch.bool)
x[mask] = values
return x
# return torch.ops.aten.index_put(x, (mask,), values)

model = Model()
xs = (
torch.arange(2 * 3 * 5).reshape((2, 3, 5)).to(torch.float32),
torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32),
)
expected = model(*xs)
ep = torch.onnx.export(model, xs, dynamo=True)
sess = onnxruntime.InferenceSession(
ep.model_proto.SerializeToString(),
providers=["CPUExecutionProvider"],
)
feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs]))
got = sess.run(None, feeds)[0]
torch.testing.assert_close(expected, torch.from_numpy(got))

def test_aten_index_put_mask_bool_fixed_broadcast_3d_2(self):
class Model(torch.nn.Module):
def forward(self, x, values):
x = x.clone()
mask = torch.tensor([[True, False, False], [True, True, False]]).to(torch.bool)
x[mask] = values
return x
# return torch.ops.aten.index_put(x, (mask,), values)

model = Model()
xs = (
torch.arange(2 * 3 * 5).reshape((2, 3, 5)).to(torch.float32),
torch.tensor([700, 800, 900, 1000, 1100], dtype=torch.float32),
)
expected = model(*xs)
ep = torch.onnx.export(model, xs, dynamo=True)
sess = onnxruntime.InferenceSession(
ep.model_proto.SerializeToString(),
providers=["CPUExecutionProvider"],
)
feeds = dict(zip([i.name for i in sess.get_inputs()], [x.numpy() for x in xs]))
got = sess.run(None, feeds)[0]
torch.testing.assert_close(expected, torch.from_numpy(got))


if __name__ == "__main__":
unittest.main(verbosity=2)
6 changes: 2 additions & 4 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,12 +852,10 @@ def _where_input_wrangler(
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
)
.skip(
).skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
reason="this Aten overload only supports tensor(bool) as indices",
)
.skip(reason="FIXME: https://github.com/microsoft/onnxscript/issues/1749"),
),
TorchLibOpInfo(
"index_put",
core_ops.aten_index_put,
Expand Down