diff --git a/python/tvm/topi/nn/softmax.py b/python/tvm/topi/nn/softmax.py
index cb6d5b321eac..2d6921b26dfa 100644
--- a/python/tvm/topi/nn/softmax.py
+++ b/python/tvm/topi/nn/softmax.py
@@ -136,16 +136,38 @@ def log_softmax(x, axis=-1):
     output : tvm.te.Tensor
         2-D output with same shape
     """
-    assert len(x.shape) == 2, "only support 2-dim log softmax"
-    # pylint: disable=R1714
-    assert axis == -1 or axis == len(x.shape) - 1, "only support last axis log softmax"
-    m, n = x.shape
-    k = te.reduce_axis((0, n), name="k")
-    max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
-    k = te.reduce_axis((0, n), name="k")
-    expsum = te.compute((m,), lambda i: te.sum(te.exp(x[i, k] - max_elem[i]), axis=k))
+    shape = x.shape
+    if axis < 0:
+        axis = len(shape) + axis
+    if axis >= len(shape):
+        raise ValueError("axis parameter should be less than input dim")
+
+    k1 = te.reduce_axis((0, shape[axis]), name="k")
+    k2 = te.reduce_axis((0, shape[axis]), name="k")
+
+    def insert_reduce_index(indices, reduce_index):
+        return indices[:axis] + (reduce_index,) + indices[axis:]
+
+    def get_non_reduce_indices(indices):
+        return tuple([var for (i, var) in enumerate(indices) if i != axis])
+
+    def _compute_max(*indices):
+        eval_range = insert_reduce_index(indices, k1)
+        return tvm.te.max(x[eval_range], axis=k1)
+
+    def _compute_expsum(max_elem, *indices):
+        eval_range = insert_reduce_index(indices, k2)
+        return te.sum(te.exp(x[eval_range] - max_elem[indices]), axis=k2)
+
+    def _normalize(max_elem, expsum, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return x[indices] - max_elem[non_reduce_indices] - te.log(expsum[non_reduce_indices])
+
+    reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
+    max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")
+    expsum = te.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices))
     return te.compute(
-        x.shape,
-        lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]),
+        shape,
+        lambda *indices: _normalize(max_elem, expsum, *indices),
         attrs={"axis": axis},
     )
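Note (illustration, not part of the patch): insert_reduce_index and get_non_reduce_indices above are the usual N-D reduction bookkeeping; max and expsum live on a shape with the softmax axis removed, and each output element indexes them with that axis dropped. In plain NumPy the same three-stage decomposition looks roughly like this, with np.expand_dims re-inserting the reduced axis:

    import numpy as np

    def log_softmax_nd_sketch(x, axis=-1):
        # 1) max over the softmax axis (the "reduced_shape" tensor)
        # 2) sum of exp(x - max) over the same axis
        # 3) normalize: x - max - log(expsum), with the reduced tensors broadcast back
        axis = axis % x.ndim
        max_elem = np.max(x, axis=axis)
        expsum = np.sum(np.exp(x - np.expand_dims(max_elem, axis)), axis=axis)
        return x - np.expand_dims(max_elem, axis) - np.log(np.expand_dims(expsum, axis))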
diff --git a/python/tvm/topi/testing/softmax_python.py b/python/tvm/topi/testing/softmax_python.py
index da2893d1fa7b..6be5d48a671a 100644
--- a/python/tvm/topi/testing/softmax_python.py
+++ b/python/tvm/topi/testing/softmax_python.py
@@ -19,43 +19,39 @@
 import numpy as np
 
 
-def softmax_python(a_np):
+def softmax_python(a_np, axis=1):
     """Softmax operator.
     Parameters
     ----------
     a_np : numpy.ndarray
-        2-D input data
+        N-D input data
 
     Returns
     -------
     output_np : numpy.ndarray
-        2-D output with same shape
+        N-D output with same shape
     """
-    assert len(a_np.shape) == 2, "only support 2-dim softmax"
-    max_elem = np.amax(a_np, axis=1)
-    max_elem = max_elem.reshape(max_elem.shape[0], 1)
+    max_elem = np.amax(a_np, axis=axis, keepdims=True)
     e = np.exp(a_np - max_elem)
-    expsum = np.sum(e, axis=1)
-    out_np = e / expsum[:, None]
+    expsum = np.sum(e, axis=axis, keepdims=True)
+    out_np = e / expsum
     return out_np
 
 
-def log_softmax_python(a_np):
+def log_softmax_python(a_np, axis=1):
     """Log_softmax operator.
     Parameters
     ----------
     a_np : numpy.ndarray
-        2-D input data
+        N-D input data
 
     Returns
     -------
     output_np : numpy.ndarray
-        2-D output with same shape
+        N-D output with same shape
     """
-    assert len(a_np.shape) == 2, "only support 2-dim log_softmax"
-    max_elem = np.amax(a_np, axis=1)
-    max_elem = max_elem.reshape(max_elem.shape[0], 1)
+    max_elem = np.amax(a_np, axis=axis, keepdims=True)
     e = np.exp(a_np - max_elem)
-    expsum = np.sum(e, axis=1)
-    out_np = a_np - max_elem - np.log(expsum[:, None])
+    expsum = np.sum(e, axis=axis, keepdims=True)
+    out_np = a_np - max_elem - np.log(expsum)
     return out_np
diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 9b6754c5e847..5475fc772e77 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -39,7 +39,7 @@ def _schedule_softmax(softmax_op, s, outs):
         delta = None
         max_elem = softmax_op.input_tensors[1]
         expsum = softmax_op.input_tensors[2]
-        axis = 1
+        axis = int(softmax_op.attrs["axis"])
     else:
         raise ValueError(
             "Tag is expected to be softmax_output or log_softmax_output. \
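Note (illustration, not part of the patch): with keepdims=True the reduced max and expsum stay broadcastable against the N-D input, so the explicit reshape of the old 2-D reference is no longer needed. A quick sanity check one might run on the updated helper (shapes chosen arbitrarily):

    import numpy as np
    from tvm.topi.testing import log_softmax_python

    a = np.random.uniform(size=(2, 3, 4)).astype("float32")
    ref = log_softmax_python(a, axis=1)
    # exponentiating and summing over the softmax axis should give all ones
    np.testing.assert_allclose(np.exp(ref).sum(axis=1), np.ones((2, 4)), rtol=1e-5)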
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 44df40d3b0bd..4ce422ae8893 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -249,46 +249,48 @@ def test_expand_dims_infer_type():
 
 @tvm.testing.uses_gpu
 def test_softmax():
-    for dtype in ["float16", "float32"]:
-        # Softmax accuracy for float16 is poor
-        if dtype == "float16":
-            return
-        shape = (10, 4)
-        x = relay.var("x", shape=shape, dtype=dtype)
-        y = relay.nn.softmax(x, axis=1)
-        assert "nn.softmax" in y.astext()
-        yy = run_infer_type(y)
-        assert yy.checked_type == relay.TensorType(shape, dtype)
-        func = relay.Function([x], y)
-        x_data = np.random.uniform(size=shape).astype(dtype)
-        ref_res = tvm.topi.testing.softmax_python(x_data)
-        for target, dev in tvm.testing.enabled_targets():
-            op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
-                x_data
-            )
-            np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
+    for shape in [(10, 4), (10, 5, 4)]:
+        for dtype in ["float16", "float32"]:
+            # Softmax accuracy for float16 is poor
+            if dtype == "float16":
+                continue
+            x = relay.var("x", shape=shape, dtype=dtype)
+            y = relay.nn.softmax(x, axis=1)
+            assert "nn.softmax" in y.astext()
+            yy = run_infer_type(y)
+            assert yy.checked_type == relay.TensorType(shape, dtype)
+            func = relay.Function([x], y)
+            x_data = np.random.uniform(size=shape).astype(dtype)
+            ref_res = tvm.topi.testing.softmax_python(x_data, axis=1)
+            for target, dev in tvm.testing.enabled_targets():
+                op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
+                    x_data
+                )
+                np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
 
 @tvm.testing.uses_gpu
 def test_log_softmax():
-    for dtype in ["float16", "float32"]:
-        # Softmax accuracy for float16 is poor
-        if dtype == "float16":
-            return
-        shape = (10, 4)
-        x = relay.var("x", shape=shape, dtype=dtype)
-        y = relay.nn.log_softmax(x, axis=1)
-        assert "nn.log_softmax" in y.astext()
-        yy = run_infer_type(y)
-        assert yy.checked_type == relay.TensorType(shape, dtype)
-        func = relay.Function([x], y)
-        x_data = np.random.uniform(size=shape).astype(dtype)
-        ref_res = tvm.topi.testing.log_softmax_python(x_data)
-        for target, dev in tvm.testing.enabled_targets():
-            op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
-                x_data
-            )
-            np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
+    for shape in [(10, 4), (10, 5, 4)]:
+        for dtype in ["float16", "float32"]:
+            # Softmax accuracy for float16 is poor
+            if dtype == "float16":
+                continue
+            x = relay.var("x", shape=shape, dtype=dtype)
+            y = relay.nn.log_softmax(x, axis=1)
+            assert "nn.log_softmax" in y.astext()
+            yy = run_infer_type(y)
+            assert yy.checked_type == relay.TensorType(shape, dtype)
+            func = relay.Function([x], y)
+            x_data = np.random.uniform(size=shape).astype(dtype)
+            ref_res = tvm.topi.testing.log_softmax_python(x_data, axis=1)
+            for target, dev in tvm.testing.enabled_targets():
+                if target == "nvptx":
+                    continue
+                op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
+                    x_data
+                )
+                np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
 
 
 @tvm.testing.uses_gpu
diff --git a/tests/python/topi/python/test_topi_softmax.py b/tests/python/topi/python/test_topi_softmax.py
index 8243211a8674..8e5e039b1448 100644
--- a/tests/python/topi/python/test_topi_softmax.py
+++ b/tests/python/topi/python/test_topi_softmax.py
@@ -50,7 +50,7 @@
     "log_softmax": {
         "topi": topi.nn.log_softmax,
         "ref": tvm.topi.testing.log_softmax_python,
-        "dimensions": [2],
+        "dimensions": [2, 3],
         "axis": [1],
     },
 }
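Note (illustration, not part of the patch): the TOPI tests above drive the operator through the per-target schedules; a minimal end-to-end sketch of what they exercise, assuming an llvm target and the default TE schedule, could look like this:

    import numpy as np
    import tvm
    import tvm.topi.testing
    from tvm import te, topi

    shape, axis = (10, 5, 4), 1
    A = te.placeholder(shape, name="A", dtype="float32")
    B = topi.nn.log_softmax(A, axis=axis)
    s = te.create_schedule(B.op)  # naive schedule; the real tests use the tuned x86/CUDA ones
    f = tvm.build(s, [A, B], "llvm")

    a_np = np.random.uniform(size=shape).astype("float32")
    b = tvm.nd.array(np.zeros(shape, dtype="float32"))
    f(tvm.nd.array(a_np), b)
    ref = tvm.topi.testing.log_softmax_python(a_np, axis=axis)
    np.testing.assert_allclose(b.numpy(), ref, rtol=1e-5)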