This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

fix SoftmaxOutput resource bug (#14302)
* fix SoftmaxOutput resource

* remove BackwardResource since SoftmaxOutput is not a legacy op (see the sketch after this list)

* add test_SoftmaxOutput_normalization

* ignore_label=-1 when use_ignore=False

* retrigger CI

* add multi_output test for SoftmaxOutput

* rename test_SoftmaxOutput_normalization to test_softmax_output_normalization

* retrigger CI

* retrigger CI

* fix test bug
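For background on the BackwardResource bullet above (this note and the snippet are a hedged sketch, not text from the commit): operators registered through NNVM declare the resources their kernels need via the FResourceRequest attribute, whereas OperatorProperty::BackwardResource is only consulted for operators that still go through the legacy operator interface. Since the backward pass of SoftmaxOutput is an NNVM-registered op, the legacy hook had no effect and the request has to be attached to _backward_SoftmaxOutput directly, as the diff below does. A minimal sketch of the two registration styles, using an illustrative op name:

// Legacy path: only used by OperatorProperty-based operators, so it has no
// effect on an NNVM-registered backward op.
std::vector<ResourceRequest> BackwardResource(
    const mxnet::ShapeVector& in_shape) const override {
  return {ResourceRequest::kTempSpace};
}

// NNVM path: the resource request is attached to the op registration itself.
NNVM_REGISTER_OP(_backward_SomeOp)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  });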
wkcn authored Mar 5, 2019
1 parent 7243806 commit 5065f13
Showing 3 changed files with 74 additions and 5 deletions.
5 changes: 0 additions & 5 deletions src/operator/softmax_output-inl.h
@@ -426,11 +426,6 @@ class SoftmaxOutputProp : public OperatorProperty {
    return {{in_data[softmaxout_enum::kData], out_data[softmaxout_enum::kOut]}};
  }

  std::vector<ResourceRequest> BackwardResource(
      const mxnet::ShapeVector &in_shape) const override {
    return {ResourceRequest::kTempSpace};
  }

  Operator* CreateOperator(Context ctx) const override {
    LOG(FATAL) << "Not Implemented.";
    return NULL;
3 changes: 3 additions & 0 deletions src/operator/softmax_output.cc
@@ -262,6 +262,9 @@ NNVM_REGISTER_OP(_backward_SoftmaxOutput)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
  return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n){
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr_parser(ParamParser<SoftmaxOutputParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxOutputGradCompute<cpu>);
} // namespace op
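For context, a rough sketch (illustrative names and shapes, not code from this commit) of how a backward kernel consumes the temp space granted by the FResourceRequest registration above; without that registration, ctx.requested stays empty and the workspace lookup in the gradient kernel fails, which is presumably the resource bug this commit addresses:

template<typename xpu>
void SomeOpGradCompute(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  // This workspace is only available because the op declared
  // ResourceRequest::kTempSpace via the FResourceRequest attribute.
  Tensor<xpu, 1, real_t> workspace =
      ctx.requested[0].get_space_typed<xpu, 1, real_t>(Shape1(256), s);
  // ... use `workspace` as scratch space for the gradient computation ...
}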
71 changes: 71 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -6496,6 +6496,75 @@ def test_softmax():
    check_smoothed_softmax_grad(default_context())


@with_seed()
def test_softmax_output_normalization():
    def _softmaxoutput_normalization(multi_output, use_ignore, normalization):
        # Compare SoftmaxOutput's forward output and data gradient against a
        # manual reference for the given normalization / ignore_label setup.
        grad_scale = np.random.random()
        batch_size = 8
        num_labels = 6
        H, W = 3, 3
        ignore_label = np.random.randint(0, num_labels) if use_ignore else -1

        if multi_output:
            data_shape = (batch_size, num_labels, H, W)
            label_shape = (batch_size, H, W)
        else:
            data_shape = (batch_size, num_labels)
            label_shape = (batch_size, )

        data = mx.nd.random.uniform(-1, 1, shape=data_shape)
        label = mx.nd.random.randint(
            0, num_labels, shape=label_shape).astype('float32')
        data.attach_grad()

        kwargs = dict(grad_scale=grad_scale,
                      normalization=normalization, multi_output=multi_output)
        if use_ignore:
            kwargs.update(use_ignore=True, ignore_label=ignore_label)

        with mx.autograd.record():
            out = mx.nd.SoftmaxOutput(data=data, label=label, **kwargs)
        out.backward(mx.nd.ones_like(data))

        # Reference forward pass: softmax over the class axis.
        exp_data = mx.nd.exp(data)
        softmax_data = exp_data / exp_data.sum(1, keepdims=True)
        argmax_data = mx.nd.argmax(data, axis=1)

        assert_almost_equal(out.asnumpy(), softmax_data.asnumpy())
        # Reference gradient: softmax(data) - one_hot(label), masked and scaled
        # according to ignore_label and the normalization mode.
        one_hot_label = mx.nd.one_hot(label, num_labels)
        if multi_output:
            one_hot_label = one_hot_label.transpose((0, 3, 1, 2))
        data_grad = softmax_data - one_hot_label

        if use_ignore:
            if multi_output:
                data_grad *= (label !=
                              ignore_label).reshape((batch_size, 1, H, W))
            else:
                data_grad *= (label != ignore_label).reshape((batch_size, 1))

        valid_cnt = 1
        if normalization == 'batch':
            valid_cnt = batch_size
        elif normalization == 'valid':
            valid_cnt = mx.nd.maximum(1, (label != ignore_label).sum())
        scale = grad_scale / valid_cnt

        if multi_output:
            if normalization != 'valid':
                scale /= H * W

        data_grad *= scale

        assert_almost_equal(data.grad.asnumpy(), data_grad.asnumpy())

    for multi_output in [False, True]:
        for use_ignore in [False, True]:
            for normalization in ['null', 'batch', 'valid']:
                _softmaxoutput_normalization(
                    multi_output, use_ignore, normalization)


@with_seed()
def test_slice():
    def test_slice_forward_backward(a, index):
@@ -6533,6 +6602,7 @@ def test_slice_forward_backward(a, index):
    slice_sym = mx.sym.slice(data, begin=[0, None], end=[1, None], step=[2, -1])
    check_numeric_gradient(slice_sym, [in_data])


def test_slice_partial_infer():
    def check_slice_partial_infer(data, begin, end, step, expected_out_shape):
        out = mx.sym.slice(data, begin=begin, end=end, step=step)
@@ -6555,6 +6625,7 @@ def check_slice_axis_partial_infer(data, axis, begin, end, expected_out_shape):
    check_slice_axis_partial_infer(var1, 0, 0, 5, (5, 0))
    check_slice_axis_partial_infer(var1, 1, 0, 5, (10, 0))


@with_seed()
def test_float16_min_max():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/9007"""
Expand Down
