Skip to content

Commit 437d00a

Browse files
authored
[Relax][ONNX] Update Reduce ops to support axes as input (#18090)
- Support axes as an input for Reduce ops (e.g., ReduceL2, ReduceMax, …) - Add corresponding test cases
1 parent 23bcbc5 commit 437d00a

File tree

2 files changed

+222
-17
lines changed

2 files changed

+222
-17
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 212 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2509,6 +2509,29 @@ def _impl_v11(cls, bb, inputs, attr, params):
25092509
keepdims = attr.get("keepdims", 1)
25102510
return relax.op.max(data, axes, keepdims)
25112511

2512+
@classmethod
2513+
def _impl_v18(cls, bb, inputs, attr, params):
2514+
data = inputs[0]
2515+
keepdims = attr.get("keepdims", 1)
2516+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2517+
2518+
# Optional axes input
2519+
axes = None
2520+
if len(inputs) > 1 and inputs[1] is not None:
2521+
axes_const = get_constant(inputs[1], params)
2522+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2523+
axes = axes_const.data.numpy().tolist()
2524+
2525+
# If axes is empty and noop_with_empty_axes is False, reduce all dims
2526+
if not axes and not noop_with_empty_axes:
2527+
return relax.op.max(data, None, keepdims)
2528+
# If axes is empty and noop_with_empty_axes is True, return input unchanged
2529+
elif not axes and noop_with_empty_axes:
2530+
return data
2531+
# Otherwise reduce over specified axes
2532+
else:
2533+
return relax.op.max(data, axes, keepdims)
2534+
25122535

25132536
class ReduceMin(OnnxOpConverter):
25142537
"""Converts an onnx ReduceMin node into an equivalent Relax expression."""
@@ -2520,6 +2543,29 @@ def _impl_v11(cls, bb, inputs, attr, params):
25202543
keepdims = attr.get("keepdims", 1)
25212544
return relax.op.min(data, axes, keepdims)
25222545

2546+
@classmethod
2547+
def _impl_v18(cls, bb, inputs, attr, params):
2548+
data = inputs[0]
2549+
keepdims = attr.get("keepdims", 1)
2550+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2551+
2552+
# Optional axes input
2553+
axes = None
2554+
if len(inputs) > 1 and inputs[1] is not None:
2555+
axes_const = get_constant(inputs[1], params)
2556+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2557+
axes = axes_const.data.numpy().tolist()
2558+
2559+
# If axes is empty and noop_with_empty_axes is False, reduce all dims
2560+
if not axes and not noop_with_empty_axes:
2561+
return relax.op.min(data, None, keepdims)
2562+
# If axes is empty and noop_with_empty_axes is True, return input unchanged
2563+
elif not axes and noop_with_empty_axes:
2564+
return data
2565+
# Otherwise reduce over specified axes
2566+
else:
2567+
return relax.op.min(data, axes, keepdims)
2568+
25232569

25242570
class ReduceSum(OnnxOpConverter):
25252571
"""Converts an onnx ReduceSum node into an equivalent Relax expression."""
@@ -2534,11 +2580,25 @@ def _impl_v11(cls, bb, inputs, attr, params):
25342580
@classmethod
25352581
def _impl_v13(cls, bb, inputs, attr, params):
25362582
data = inputs[0]
2537-
axes = inputs[1]
25382583
keepdims = attr.get("keepdims", 1)
2539-
assert isinstance(axes, relax.Constant), "Only constant axes currently supported."
2540-
axes = axes.data.numpy().tolist()
2541-
return relax.op.sum(data, axes, keepdims)
2584+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2585+
2586+
# Optional axes input
2587+
axes = None
2588+
if len(inputs) > 1 and inputs[1] is not None:
2589+
axes_const = get_constant(inputs[1], params)
2590+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2591+
axes = axes_const.data.numpy().tolist()
2592+
2593+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2594+
if not axes and not noop_with_empty_axes:
2595+
return relax.op.sum(data, None, keepdims)
2596+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2597+
elif not axes and noop_with_empty_axes:
2598+
return data
2599+
# If axes is provided, reduce over the specified axes
2600+
else:
2601+
return relax.op.sum(data, axes, keepdims)
25422602

25432603

25442604
class ReduceMean(OnnxOpConverter):
@@ -2551,6 +2611,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
25512611
keepdims = attr.get("keepdims", 1)
25522612
return relax.op.mean(data, axes, keepdims)
25532613

2614+
@classmethod
2615+
def _impl_v18(cls, bb, inputs, attr, params):
2616+
data = inputs[0]
2617+
keepdims = attr.get("keepdims", 1)
2618+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2619+
2620+
# Optional axes input
2621+
axes = None
2622+
if len(inputs) > 1 and inputs[1] is not None:
2623+
axes_const = get_constant(inputs[1], params)
2624+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2625+
axes = axes_const.data.numpy().tolist()
2626+
2627+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2628+
if not axes and not noop_with_empty_axes:
2629+
return relax.op.mean(data, None, keepdims)
2630+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2631+
elif not axes and noop_with_empty_axes:
2632+
return data
2633+
# If axes is provided, reduce over the specified axes
2634+
else:
2635+
return relax.op.mean(data, axes, keepdims)
2636+
25542637

25552638
class ReduceProd(OnnxOpConverter):
25562639
"""Converts an onnx ReduceProd node into an equivalent Relax expression."""
@@ -2562,6 +2645,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
25622645
keepdims = attr.get("keepdims", 1)
25632646
return relax.op.prod(data, axes, keepdims)
25642647

2648+
@classmethod
2649+
def _impl_v18(cls, bb, inputs, attr, params):
2650+
data = inputs[0]
2651+
keepdims = attr.get("keepdims", 1)
2652+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2653+
2654+
# Optional axes input
2655+
axes = None
2656+
if len(inputs) > 1 and inputs[1] is not None:
2657+
axes_const = get_constant(inputs[1], params)
2658+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2659+
axes = axes_const.data.numpy().tolist()
2660+
2661+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2662+
if not axes and not noop_with_empty_axes:
2663+
return relax.op.prod(data, None, keepdims)
2664+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2665+
elif not axes and noop_with_empty_axes:
2666+
return data
2667+
# If axes is provided, reduce over the specified axes
2668+
else:
2669+
return relax.op.prod(data, axes, keepdims)
2670+
25652671

25662672
class ReduceLogSumExp(OnnxOpConverter):
25672673
"""Converts an onnx ReduceLogSumExp node into an equivalent Relax expression."""
@@ -2579,6 +2685,38 @@ def _impl_v13(cls, bb, inputs, attr, params):
25792685
out_x = relax.op.squeeze(out_x, axes)
25802686
return out_x
25812687

2688+
@classmethod
2689+
def _impl_v18(cls, bb, inputs, attr, params):
2690+
x = inputs[0]
2691+
keepdims = attr.get("keepdims", 1)
2692+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2693+
2694+
# Optional axes input (second input)
2695+
axes = None
2696+
if len(inputs) > 1 and inputs[1] is not None:
2697+
axes_const = get_constant(inputs[1], params)
2698+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2699+
axes = axes_const.data.numpy().tolist()
2700+
2701+
# Calculate LogSumExp
2702+
log_sum_exp = lambda axes: (
2703+
max_x := relax.op.max(x, axes, True),
2704+
exp_x := relax.op.exp(relax.op.subtract(x, max_x)),
2705+
sum_x := relax.op.sum(exp_x, axes, True),
2706+
out_x := relax.op.add(relax.op.log(sum_x), max_x),
2707+
relax.op.squeeze(out_x, axes) if not keepdims else out_x,
2708+
)[-1]
2709+
2710+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2711+
if not axes and not noop_with_empty_axes:
2712+
return log_sum_exp(None)
2713+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2714+
elif not axes and noop_with_empty_axes:
2715+
return x
2716+
# If axes is provided, reduce over the specified axes
2717+
else:
2718+
return log_sum_exp(axes)
2719+
25822720

25832721
class ReduceLogSum(OnnxOpConverter):
25842722
"""Converts an onnx ReduceLogSum node into an equivalent Relax expression."""
@@ -2590,6 +2728,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
25902728
keepdims = attr.get("keepdims", 1)
25912729
return relax.op.log(relax.op.sum(data, axes, keepdims))
25922730

2731+
@classmethod
2732+
def _impl_v18(cls, bb, inputs, attr, params):
2733+
data = inputs[0]
2734+
keepdims = attr.get("keepdims", 1)
2735+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2736+
2737+
# Optional axes input
2738+
axes = None
2739+
if len(inputs) > 1 and inputs[1] is not None:
2740+
axes_const = get_constant(inputs[1], params)
2741+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2742+
axes = axes_const.data.numpy().tolist()
2743+
2744+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2745+
if not axes and not noop_with_empty_axes:
2746+
return relax.op.log(relax.op.sum(data, None, keepdims))
2747+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2748+
elif not axes and noop_with_empty_axes:
2749+
return data
2750+
# If axes is provided, reduce over the specified axes
2751+
else:
2752+
return relax.op.log(relax.op.sum(data, axes, keepdims))
2753+
25932754

25942755
class ReduceSumSquare(OnnxOpConverter):
25952756
"""Converts an onnx ReduceSumSquare node into an equivalent Relax expression."""
@@ -2601,6 +2762,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
26012762
keepdims = attr.get("keepdims", 1)
26022763
return relax.op.sum(relax.op.multiply(data, data), axes, keepdims)
26032764

2765+
@classmethod
2766+
def _impl_v18(cls, bb, inputs, attr, params):
2767+
data = inputs[0]
2768+
keepdims = attr.get("keepdims", 1)
2769+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2770+
2771+
# Optional axes input
2772+
axes = None
2773+
if len(inputs) > 1 and inputs[1] is not None:
2774+
axes_const = get_constant(inputs[1], params)
2775+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2776+
axes = axes_const.data.numpy().tolist()
2777+
2778+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2779+
if not axes and not noop_with_empty_axes:
2780+
return relax.op.sum(relax.op.multiply(data, data), None, keepdims)
2781+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2782+
elif not axes and noop_with_empty_axes:
2783+
return data
2784+
# If axes is provided, reduce over the specified axes
2785+
else:
2786+
return relax.op.sum(relax.op.multiply(data, data), axes, keepdims)
2787+
26042788

26052789
class ReduceL1(OnnxOpConverter):
26062790
"""Converts an onnx ReduceL1 node into an equivalent Relax expression."""
@@ -2631,7 +2815,7 @@ def _impl_v18(cls, bb, inputs, attr, params):
26312815
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
26322816
elif not axes and noop_with_empty_axes:
26332817
return data
2634-
# If axes is provided, reduce over specified axes
2818+
# If axes is provided, reduce over the specified axes
26352819
else:
26362820
return relax.op.sum(relax.op.abs(data), axes, keepdims)
26372821

@@ -2646,6 +2830,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
26462830
keepdims = attr.get("keepdims", 1)
26472831
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims))
26482832

2833+
@classmethod
2834+
def _impl_v18(cls, bb, inputs, attr, params):
2835+
data = inputs[0]
2836+
keepdims = attr.get("keepdims", 1)
2837+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2838+
2839+
# Optional axes input
2840+
axes = None
2841+
if len(inputs) > 1 and inputs[1] is not None:
2842+
axes_const = get_constant(inputs[1], params)
2843+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2844+
axes = axes_const.data.numpy().tolist()
2845+
2846+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2847+
if not axes and not noop_with_empty_axes:
2848+
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), None, keepdims))
2849+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2850+
elif not axes and noop_with_empty_axes:
2851+
return data
2852+
# If axes is provided, reduce over the specified axes
2853+
else:
2854+
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims))
2855+
26492856

26502857
class ArgMax(OnnxOpConverter):
26512858
"""Converts an onnx ArgMax node into an equivalent Relax expression."""

tests/python/relax/test_frontend_onnx.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,23 +1580,22 @@ def verify_reduce_func(func, data, axis, keepdims):
15801580
def create_reduce_test_parameters_axes_input():
15811581
output = []
15821582
for dynamic in [True, False]:
1583-
# TODO(@vacu9708): Enable the tests after implementing other reduce ops
1584-
# output.append(("ReduceMax", dynamic, 20))
1585-
# output.append(("ReduceMean", dynamic, 18))
1586-
# output.append(("ReduceMin", dynamic, 20))
1587-
# output.append(("ReduceProd", dynamic, 18))
1588-
# output.append(("ReduceSum", dynamic, 13))
1589-
# output.append(("ReduceSumSquare", dynamic, 18))
1590-
# output.append(("ReduceLogSum", dynamic, 18))
1591-
# output.append(("ReduceLogSumExp", dynamic, 18))
1583+
output.append(("ReduceMax", dynamic, 18))
1584+
output.append(("ReduceMean", dynamic, 18))
1585+
output.append(("ReduceMin", dynamic, 18))
1586+
output.append(("ReduceProd", dynamic, 18))
1587+
output.append(("ReduceSum", dynamic, 13))
1588+
output.append(("ReduceSumSquare", dynamic, 18))
1589+
output.append(("ReduceLogSum", dynamic, 18))
1590+
output.append(("ReduceLogSumExp", dynamic, 18))
15921591
output.append(("ReduceL1", dynamic, 18))
1593-
# output.append(("ReduceL2", dynamic, 18))
1592+
output.append(("ReduceL2", dynamic, 18))
15941593
return output
15951594

15961595

15971596
@pytest.mark.parametrize("func, dynamic, opset", create_reduce_test_parameters_axes_input())
15981597
def test_all_reduce_funcs_axes_input(func, dynamic, opset):
1599-
def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
1598+
def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes=False):
16001599
inshape = data.shape
16011600

16021601
inputs = ["x"]
@@ -1698,7 +1697,6 @@ def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
16981697
np.random.randn(3, 3, 3, 1).astype(np.float32),
16991698
axes=(1, 2),
17001699
keepdims=keepdims,
1701-
noop_with_empty_axes=True,
17021700
)
17031701

17041702

0 commit comments

Comments
 (0)