Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2420,6 +2420,8 @@ def _impl_v1(cls, inputs, attr, params):
axis += ndim
if axis == 0:
reshape_shape = [-1]
elif axis == ndim - 1:
return _op.nn.softmax(inputs[0], axis=axis)
else:
axis_val = [in_shape[i] for i in range(axis)]
reshape_shape = [np.prod(axis_val)] + [-1]
Expand Down
35 changes: 28 additions & 7 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,31 +1591,52 @@ def test_upsample3d_trilinear(target, dev):

@tvm.testing.parametrize_targets
def test_softmax(target, dev):
def verify_softmax(inshape, axis):
def verify_softmax(inshape, axis, opset=None, dynamic=False):
opname = "Softmax"
indata = np.random.uniform(size=inshape).astype(np.float32)
outshape = inshape
y = helper.make_node(opname, ["in"], ["out"])
node_list = []
input_node_list = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(inshape))]
output_node_list = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))]
input_list = [np.random.uniform(size=inshape).astype(np.float32)]
softmax_inputs = ["in"]

if dynamic:
input_node_list.append(
helper.make_tensor_value_info("shape", TensorProto.INT64, [len(inshape)])
)
input_list.append(np.asarray(inshape))
reshape_node = helper.make_node("Reshape", ["in", "shape"], ["dynamic_in"])
softmax_inputs[0] = "dynamic_in"
node_list += [reshape_node]

y = helper.make_node(opname, softmax_inputs, ["out"])
if axis is not None:
axis_attr = helper.make_attribute("axis", axis)
y.attribute.append(axis_attr)
node_list.append(y)

graph = helper.make_graph(
[y],
node_list,
opname + "_test",
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))],
inputs=input_node_list,
outputs=output_node_list,
)

model = helper.make_model(graph, producer_name=opname + "_test")
verify_with_ort_with_inputs(model, [indata], target=target, dev=dev)
verify_with_ort_with_inputs(
model, input_list, use_vm=True, opset=opset, target=target, dev=dev
)

verify_softmax((1, 10), None)
verify_softmax((1, 10), 1)
verify_softmax((1, 2, 3, 10), 0)
verify_softmax((1, 2, 3, 10), 2)
verify_softmax((1, 2, 3, 4, 10), 3)
verify_softmax((1, 2, 3, 4, 10), 4)
verify_softmax((1, 10), -1, dynamic=True)
verify_softmax((1, 2, 3, 10), -1, dynamic=True)
verify_softmax((1, 10), -1, opset=8, dynamic=True)
verify_softmax((1, 2, 3, 10), -1, opset=8, dynamic=True)


@tvm.testing.parametrize_targets
Expand Down