Skip to content

Commit

Permalink
Fix softmax, logsoftmax failed on empty ndarray (apache#18602)
Browse files Browse the repository at this point in the history
* Fix failing empty array (log_)softmax

* Modify test for npx (log_)softmax
  • Loading branch information
bgawrych authored and Bart Gawrych committed Jul 31, 2020
1 parent ef953fe commit 0787b56
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/operator/nn/log_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[0] == kNullOp) return;
if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
Expand Down
1 change: 1 addition & 0 deletions src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
Expand Down
46 changes: 32 additions & 14 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,14 @@ def __init__(self, axis):
def hybrid_forward(self, F, a):
return F.npx.softmax(a, axis=axis)

class TestLogSoftmax(HybridBlock):
def __init__(self, axis):
super(TestLogSoftmax, self).__init__()
self._axis = axis

def hybrid_forward(self, F, a):
return F.npx.log_softmax(a, axis=axis)

def np_softmax(x, axis=-1):
if (x.shape[axis] == 0):
return _np.sum(x, axis=axis, keepdims=True)
Expand All @@ -1577,24 +1585,34 @@ def np_softmax(x, axis=-1):
x /= _np.sum(x, axis=axis, keepdims=True)
return x

def np_log_softmax(x, axis=-1):
return _np.log(np_softmax(x, axis))

#(operator, function) tuples
tested_ops = [(TestSoftmax, np_softmax),
(TestLogSoftmax, np_log_softmax)]

# only testing 0-size shaped inputs here, other input cases have been tested in test_opeartor.py
for hybridize in [True, False]:
for shape in [(3, 0, 4), (0, 0)]:
mx_a = np.random.uniform(size=shape)
mx_a.attach_grad()
for axis in range(-len(shape), len(shape)):
test_softmax = TestSoftmax(axis)
if hybridize:
test_softmax.hybridize()
for SoftmaxOp, softmax_function in tested_ops:
for hybridize in [True, False]:
for shape in [(3, 0, 4), (0, 0)]:
mx_a = np.random.uniform(size=shape)
mx_a.attach_grad()
for axis in range(-len(shape), len(shape)):
test_softmax_op = SoftmaxOp(axis)
if hybridize:
test_softmax_op.hybridize()

with mx.autograd.record():
mx_out = test_softmax(mx_a)
with mx.autograd.record():
mx_out = test_softmax_op(mx_a)

np_out = np_softmax(mx_a.asnumpy(), axis)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
mx_out.wait_to_read()

mx_out.backward()
assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
np_out = softmax_function(mx_a.asnumpy(), axis)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)

mx_out.backward()
assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)


@with_seed()
Expand Down

0 comments on commit 0787b56

Please sign in to comment.