Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2612,6 +2612,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.sum(relax.op.abs(data), axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return relax.op.sum(relax.op.abs(data), None, keepdims)
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over specified axes
else:
return relax.op.sum(relax.op.abs(data), axes, keepdims)


class ReduceL2(OnnxOpConverter):
"""Converts an onnx ReduceL2 node into an equivalent Relax expression."""
Expand Down
153 changes: 139 additions & 14 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,24 +1503,24 @@ def verify_embedlayernormalization(
)


def create_reduce_test_parameters():
def create_reduce_test_parameters_axes_attr():
output = []
for value in [True, False]:
output.append(("ReduceMax", value))
output.append(("ReduceMean", value))
output.append(("ReduceMin", value))
output.append(("ReduceProd", value))
output.append(("ReduceSum", value))
output.append(("ReduceSumSquare", value))
output.append(("ReduceLogSum", value))
output.append(("ReduceLogSumExp", value))
output.append(("ReduceL1", value))
output.append(("ReduceL2", value))
output.append(("ReduceMax", value, 11))
output.append(("ReduceMean", value, 13))
output.append(("ReduceMin", value, 11))
output.append(("ReduceProd", value, 13))
output.append(("ReduceSum", value, 11))
output.append(("ReduceSumSquare", value, 13))
output.append(("ReduceLogSum", value, 13))
output.append(("ReduceLogSumExp", value, 13))
output.append(("ReduceL1", value, 13))
output.append(("ReduceL2", value, 13))
return output


@pytest.mark.parametrize("func, dynamic", create_reduce_test_parameters())
def test_all_reduce_funcs(func, dynamic):
@pytest.mark.parametrize("func, dynamic, opset", create_reduce_test_parameters_axes_attr())
def test_all_reduce_funcs_axes_attr(func, dynamic, opset):
def verify_reduce_func(func, data, axis, keepdims):
inshape = data.shape
outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape
Expand Down Expand Up @@ -1549,7 +1549,7 @@ def verify_reduce_func(func, data, axis, keepdims):

inputs_dict = {"x": data}
# Reduction ops accumulate arithmetic errors, so we use a higher tolerance.
check_correctness(model, inputs_dict, opset=11, rtol=1e-4, atol=1e-4)
check_correctness(model, inputs_dict, opset=opset, rtol=1e-4, atol=1e-4)

for keepdims in [True, False]:
verify_reduce_func(
Expand Down Expand Up @@ -1577,6 +1577,131 @@ def verify_reduce_func(func, data, axis, keepdims):
)


def create_reduce_test_parameters_axes_input():
output = []
for dynamic in [True, False]:
# TODO(@vacu9708): Enable the tests after implementing other reduce ops
# output.append(("ReduceMax", dynamic, 20))
# output.append(("ReduceMean", dynamic, 18))
# output.append(("ReduceMin", dynamic, 20))
# output.append(("ReduceProd", dynamic, 18))
# output.append(("ReduceSum", dynamic, 13))
# output.append(("ReduceSumSquare", dynamic, 18))
# output.append(("ReduceLogSum", dynamic, 18))
# output.append(("ReduceLogSumExp", dynamic, 18))
output.append(("ReduceL1", dynamic, 18))
# output.append(("ReduceL2", dynamic, 18))
return output


@pytest.mark.parametrize("func, dynamic, opset", create_reduce_test_parameters_axes_input())
def test_all_reduce_funcs_axes_input(func, dynamic, opset):
def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
inshape = data.shape

inputs = ["x"]
initializers = []

# Optional `axes` input
if axes is not None:
axes_name = "reduce_axes"
axes_np = np.asarray(axes, dtype=np.int64)
axes_init = helper.make_tensor(
name=axes_name,
data_type=TensorProto.INT64,
dims=axes_np.shape,
vals=axes_np,
)
initializers.append(axes_init)
inputs.append(axes_name)

# Determine input and output shapes
if not axes and not noop_with_empty_axes:
outshape = np.sum(data, axis=None, keepdims=keepdims).shape
elif not axes and noop_with_empty_axes:
outshape = inshape
else:
outshape = np.sum(data, axis=axes, keepdims=keepdims).shape

if dynamic:
in_list = ["?"] * len(inshape)
out_list = ["?"] * len(outshape)
else:
in_list = list(inshape)
out_list = list(outshape)

# Make a model node
node = helper.make_node(
func,
inputs=inputs,
outputs=["y"],
keepdims=keepdims,
noop_with_empty_axes=noop_with_empty_axes,
)

# Make a model graph and a model
graph = helper.make_graph(
[node],
"reduce18_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, in_list)],
initializer=initializers,
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_list)],
)
model = helper.make_model(graph, producer_name="reduce18_test")

# Run TVM importer vs onnxruntime
inputs_dict = {"x": data}
check_correctness(model, inputs_dict, opset=opset, rtol=1e-4, atol=1e-4)

# Verify
for keepdims in [True, False]:
# no `axes` input && `noop_with_empty_axes` = 0 -> reduce over all dimensions.
verify_reduce_func(
func,
np.random.randn(3, 2, 2).astype(np.float32),
axes=[],
keepdims=keepdims,
noop_with_empty_axes=False,
)

# no `axes` input && `noop_with_empty_axes` = 0 -> reduce over all dimensions.
verify_reduce_func(
func,
np.random.randn(3, 2, 2).astype(np.float32),
axes=None,
keepdims=keepdims,
noop_with_empty_axes=False,
)

# no `axes` input && `noop_with_empty_axes` = 1 -> return the input unchanged.
verify_reduce_func(
func,
np.random.randn(4, 3).astype(np.float32),
axes=[],
keepdims=keepdims,
noop_with_empty_axes=True,
)

# no `axes` input && `noop_with_empty_axes` = 1 -> return the input unchanged.
# (onnxruntime bug) Runtime error on the onnxruntime part
# verify_reduce_func(
# func,
# np.random.randn(4, 3).astype(np.float32),
# axes=None,
# keepdims=keepdims,
# noop_with_empty_axes=True,
# )

# `axes` provided -> reduce over specified axes.
verify_reduce_func(
func,
np.random.randn(3, 3, 3, 1).astype(np.float32),
axes=(1, 2),
keepdims=keepdims,
noop_with_empty_axes=True,
)


@pytest.mark.parametrize("in_dtype", [np.float32, np.int32])
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
@pytest.mark.parametrize("keepdims", [None, True, False])
Expand Down