
Enable fused softmax smoothing #8072

Closed
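This change adds a smooth_alpha parameter to SoftmaxOutput so that the fused backward pass computes label-smoothed cross-entropy gradients. The page carries no PR description, so the following is a sketch inferred from the diff and the new tests below: with K classes, target index k, and smoothing constant alpha, the smoothed target is

q_i = 1 - alpha           if i = k
q_i = alpha / (K - 1)     otherwise

and the gradient with respect to the logits is dL/dx_i = softmax(x)_i - q_i.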
17 changes: 15 additions & 2 deletions src/operator/softmax_output-inl.h
@@ -53,6 +53,7 @@ struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
bool preserve_shape;
int normalization;
bool out_grad;
float smooth_alpha;
DMLC_DECLARE_PARAMETER(SoftmaxOutputParam) {
DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f)
.describe("Scales the gradient by a float factor.");
@@ -78,6 +79,10 @@ struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
DMLC_DECLARE_FIELD(out_grad)
.set_default(false)
.describe("Multiplies gradient with output gradient element-wise.");
DMLC_DECLARE_FIELD(smooth_alpha)
.set_default(0.0f)
.set_range(0.0f, 1.0f)
.describe("Constant for smoothed cross-entropy gradients.");
};
};

@@ -215,9 +220,17 @@ class SoftmaxOutputOp : public Operator {
in_grad[softmaxout_enum::kData].get_with_shape<xpu, 2, DType>(data_shape, s);
index_t valid_cnt = label.shape_.Size();
if (param_.use_ignore) {
SoftmaxGrad(grad, out, label, static_cast<DType>(param_.ignore_label));
if (param_.smooth_alpha == 0.0f) {
SoftmaxGrad(grad, out, label, static_cast<DType>(param_.ignore_label));
} else {
SmoothSoftmaxGrad(grad, out, label, static_cast<DType>(param_.ignore_label), param_.smooth_alpha);
}
} else {
SoftmaxGrad(grad, out, label);
if (param_.smooth_alpha == 0.0f) {
SoftmaxGrad(grad, out, label);
} else {
SmoothSoftmaxGrad(grad, out, label, param_.smooth_alpha);
}
}
if (param_.normalization == softmaxout_enum::kBatch) {
valid_cnt = label.size(0);
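The new SmoothSoftmaxGrad branch itself is not shown in this hunk; the NumPy sketch below reproduces the gradient that the unit tests in the next file expect (assumptions: a single example, no ignore_label handling, and smooth_one_hot is an illustrative helper, not part of this patch).

import numpy as np

def smooth_one_hot(label, num_classes, smooth_alpha):
    # Illustrative helper: puts 1 - smooth_alpha on the target class and
    # spreads smooth_alpha evenly over the remaining num_classes - 1 classes.
    target = np.full(num_classes, smooth_alpha / (num_classes - 1))
    target[label] = 1.0 - smooth_alpha
    return target

def smoothed_softmax_grad(x, label, smooth_alpha):
    # Gradient of label-smoothed cross-entropy w.r.t. the logits x,
    # matching check_smoothed_softmax_grad below: softmax(x) - smoothed target.
    p = np.exp(x - x.max())
    p /= p.sum()
    return p - smooth_one_hot(label, x.size, smooth_alpha)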
51 changes: 51 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -235,6 +235,55 @@ def test_regression():
lambda x, y : x - y)


def check_softmax_grad():
x = mx.sym.Variable('x')
label = mx.sym.Variable('label')
x_nd = mx.nd.array([[1, 6, 4, 2]], ctx=default_context())
grad_x = mx.nd.zeros((1,4), ctx=default_context())
label_nd = mx.nd.array([1], ctx=default_context())

sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0, use_ignore=False)
ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd}, args_grad={'x': grad_x})

ex.forward(is_train=True)
softmax_out = ex.outputs[0].asnumpy()
expected_softmax_out = mx.nd.softmax(x_nd).asnumpy()
print(softmax_out)
assert np.isclose(softmax_out, expected_softmax_out).all()

ex.backward(is_train=True)
grad_out = ex.grad_arrays[0].asnumpy()
k = int(label_nd[0].asscalar())
expected_grad_out = np.zeros((1,4))
expected_grad_out[0, k] = -1
assert np.isclose(grad_out - softmax_out, expected_grad_out).all()


def check_smoothed_softmax_grad():
alpha = 0.2
x = mx.sym.Variable('x')
label = mx.sym.Variable('label')
x_nd = mx.nd.array([[1, 6, 4, 2]], ctx=default_context())
grad_x = mx.nd.zeros((1,4), ctx=default_context())
label_nd = mx.nd.array([1], ctx=default_context())

sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0, use_ignore=False, smooth_alpha=alpha)
ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd}, args_grad={'x': grad_x})

ex.forward(is_train=True)
softmax_out = ex.outputs[0].asnumpy()
expected_softmax_out = mx.nd.softmax(x_nd).asnumpy()
print(softmax_out)
assert np.isclose(softmax_out, expected_softmax_out).all()

ex.backward(is_train=True)
grad_out = ex.grad_arrays[0].asnumpy()
k = int(label_nd[0].asscalar())
expected_grad_out = np.full((1,4), fill_value=-alpha/(4-1))
expected_grad_out[0, k] = - (1 - alpha)
assert np.isclose(grad_out - softmax_out, expected_grad_out).all()


def check_softmax_with_ignore_label(xpu):
X = mx.symbol.Variable('X')
L = mx.symbol.Variable('L')
@@ -289,6 +338,8 @@ def test_softmax():
check_softmax_with_shape((3, 4), default_context(), preserve_shape=False)
check_softmax_with_shape((3, 4), default_context(), preserve_shape=True)
check_softmax_with_shape((3, 4, 2), default_context(), preserve_shape=True)
check_softmax_grad()
check_smoothed_softmax_grad()


def test_python_op():
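For reference, a minimal end-to-end usage sketch of the new parameter with the symbolic API (the network itself is illustrative and not part of this change):

import mxnet as mx

data = mx.sym.Variable('data')
label = mx.sym.Variable('softmax_label')
fc = mx.sym.FullyConnected(data=data, num_hidden=10)
# smooth_alpha=0.1 routes the backward pass through SmoothSoftmaxGrad;
# smooth_alpha=0.0 (the default) keeps the original SoftmaxGrad path.
out = mx.sym.SoftmaxOutput(data=fc, label=label, smooth_alpha=0.1)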