This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix softmax, logsoftmax failed on empty ndarray #18602

Status: Merged (2 commits, Jul 1, 2020)
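
For context, a minimal repro sketch of the case this PR guards against (not part of the PR itself; it assumes the MXNet `mx.np` / `mx.npx` interface that the test below uses). With the fix, both operators return an empty output and an all-zero gradient instead of failing on a zero-size CPU input:

import mxnet as mx
from mxnet import np, npx
npx.set_np()

a = np.random.uniform(size=(3, 0, 4))   # zero-size ndarray
a.attach_grad()
with mx.autograd.record():
    out = npx.log_softmax(a, axis=1)    # previously could fail for empty inputs on CPU
out.backward()
print(out.shape, a.grad.shape)          # (3, 0, 4) (3, 0, 4)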
src/operator/nn/log_softmax.cc (1 addition, 0 deletions)

@@ -40,6 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
src/operator/nn/softmax-inl.h (1 addition, 1 deletion)

@@ -779,7 +779,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
   using namespace mxnet_op;
-  if (req[0] == kNullOp) return;
+  if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
   CHECK_NE(req[0], kAddTo);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   int axis = CheckAxis(param.axis, inputs[0].ndim());
src/operator/nn/softmax.cc (1 addition, 0 deletions)

@@ -41,6 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
+  if (inputs[0].shape().Size() == 0U) return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
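
All three operator changes above add the same guard: return early when the input ndarray holds no elements, so no softmax kernel (or MKL-DNN path) is launched on an empty tensor. A rough NumPy sketch of the behaviour this preserves (illustrative only; softmax_ref is a hypothetical helper, not MXNet code):

import numpy as np

def softmax_ref(x, axis=-1):
    # The early return mirrors the guard added above: a zero-size input is
    # returned as-is, with no reduction attempted on it.
    if x.size == 0:
        return x.copy()
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

print(softmax_ref(np.empty((3, 0, 4)), axis=1).shape)   # (3, 0, 4)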
tests/python/unittest/test_numpy_op.py (32 additions, 14 deletions)

@@ -1922,6 +1922,14 @@ def __init__(self, axis):
         def hybrid_forward(self, F, a):
             return F.npx.softmax(a, axis=axis)
 
+    class TestLogSoftmax(HybridBlock):
+        def __init__(self, axis):
+            super(TestLogSoftmax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.npx.log_softmax(a, axis=axis)
+
     def np_softmax(x, axis=-1):
         if (x.shape[axis] == 0):
             return _np.sum(x, axis=axis, keepdims=True)

@@ -1930,24 +1938,34 @@ def np_softmax(x, axis=-1):
         x /= _np.sum(x, axis=axis, keepdims=True)
         return x
 
+    def np_log_softmax(x, axis=-1):
+        return _np.log(np_softmax(x, axis))
+
+    #(operator, function) tuples
+    tested_ops = [(TestSoftmax, np_softmax),
+                  (TestLogSoftmax, np_log_softmax)]
+
     # only testing 0-size shaped inputs here, other input cases have been tested in test_opeartor.py
-    for hybridize in [True, False]:
-        for shape in [(3, 0, 4), (0, 0)]:
-            mx_a = np.random.uniform(size=shape)
-            mx_a.attach_grad()
-            for axis in range(-len(shape), len(shape)):
-                test_softmax = TestSoftmax(axis)
-                if hybridize:
-                    test_softmax.hybridize()
+    for SoftmaxOp, softmax_function in tested_ops:
+        for hybridize in [True, False]:
+            for shape in [(3, 0, 4), (0, 0)]:
+                mx_a = np.random.uniform(size=shape)
+                mx_a.attach_grad()
+                for axis in range(-len(shape), len(shape)):
+                    test_softmax_op = SoftmaxOp(axis)
+                    if hybridize:
+                        test_softmax_op.hybridize()
 
-                with mx.autograd.record():
-                    mx_out = test_softmax(mx_a)
+                    with mx.autograd.record():
+                        mx_out = test_softmax_op(mx_a)
 
-                np_out = np_softmax(mx_a.asnumpy(), axis)
-                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+                    mx_out.wait_to_read()
 
-                mx_out.backward()
-                assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
+                    np_out = softmax_function(mx_a.asnumpy(), axis)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+
+                    mx_out.backward()
+                    assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
 
 
 @with_seed()
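
To exercise the new coverage locally, something like the following should work, assuming a pytest-compatible runner and that the enclosing test function is named test_npx_softmax (the function name is not visible in this diff):

import pytest

# Node id below assumes the test function name; adjust if it differs.
pytest.main(["tests/python/unittest/test_numpy_op.py::test_npx_softmax", "-v"])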