Skip to content

Commit 4f5ab57

Browse files
authored
[Frontend][ONNX] Fix softmax converter when input shape is dynamic (#11507)
* [Frontend][ONNX] Fix softmax converter when input shape is dynamic
* [Frontend][ONNX] mark dynamic softmax tests as xfailed with cuda
1 parent bbca53d commit 4f5ab57

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,6 +2420,8 @@ def _impl_v1(cls, inputs, attr, params):
24202420
axis += ndim
24212421
if axis == 0:
24222422
reshape_shape = [-1]
2423+
elif axis == ndim - 1:
2424+
return _op.nn.softmax(inputs[0], axis=axis)
24232425
else:
24242426
axis_val = [in_shape[i] for i in range(axis)]
24252427
reshape_shape = [np.prod(axis_val)] + [-1]

tests/python/frontend/onnx/test_forward.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,33 +1589,56 @@ def test_upsample3d_trilinear(target, dev):
15891589
tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
15901590

15911591

1592+
# TODO: Fix softmax with dynamic input on cuda and enable this test
1593+
@tvm.testing.known_failing_targets("cuda")
15921594
@tvm.testing.parametrize_targets
15931595
def test_softmax(target, dev):
1594-
def verify_softmax(inshape, axis):
1596+
def verify_softmax(inshape, axis, opset=None, dynamic=False):
15951597
opname = "Softmax"
1596-
indata = np.random.uniform(size=inshape).astype(np.float32)
15971598
outshape = inshape
1598-
y = helper.make_node(opname, ["in"], ["out"])
1599+
node_list = []
1600+
input_node_list = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(inshape))]
1601+
output_node_list = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))]
1602+
input_list = [np.random.uniform(size=inshape).astype(np.float32)]
1603+
softmax_inputs = ["in"]
1604+
1605+
if dynamic:
1606+
input_node_list.append(
1607+
helper.make_tensor_value_info("shape", TensorProto.INT64, [len(inshape)])
1608+
)
1609+
input_list.append(np.asarray(inshape))
1610+
reshape_node = helper.make_node("Reshape", ["in", "shape"], ["dynamic_in"])
1611+
softmax_inputs[0] = "dynamic_in"
1612+
node_list += [reshape_node]
1613+
1614+
y = helper.make_node(opname, softmax_inputs, ["out"])
15991615
if axis is not None:
16001616
axis_attr = helper.make_attribute("axis", axis)
16011617
y.attribute.append(axis_attr)
1618+
node_list.append(y)
16021619

16031620
graph = helper.make_graph(
1604-
[y],
1621+
node_list,
16051622
opname + "_test",
1606-
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
1607-
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))],
1623+
inputs=input_node_list,
1624+
outputs=output_node_list,
16081625
)
16091626

16101627
model = helper.make_model(graph, producer_name=opname + "_test")
1611-
verify_with_ort_with_inputs(model, [indata], target=target, dev=dev)
1628+
verify_with_ort_with_inputs(
1629+
model, input_list, use_vm=True, opset=opset, target=target, dev=dev
1630+
)
16121631

16131632
verify_softmax((1, 10), None)
16141633
verify_softmax((1, 10), 1)
16151634
verify_softmax((1, 2, 3, 10), 0)
16161635
verify_softmax((1, 2, 3, 10), 2)
16171636
verify_softmax((1, 2, 3, 4, 10), 3)
16181637
verify_softmax((1, 2, 3, 4, 10), 4)
1638+
verify_softmax((1, 10), -1, dynamic=True)
1639+
verify_softmax((1, 2, 3, 10), -1, dynamic=True)
1640+
verify_softmax((1, 10), -1, opset=8, dynamic=True)
1641+
verify_softmax((1, 2, 3, 10), -1, opset=8, dynamic=True)
16191642

16201643

16211644
@tvm.testing.parametrize_targets

0 commit comments

Comments (0)