This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-807] Support integer label type in ctc_loss operator #12468

Merged (17 commits) on Sep 12, 2018
152 changes: 84 additions & 68 deletions src/operator/contrib/ctc_loss-inl.h
@@ -256,66 +256,69 @@ class CTCLossOp : public Operator {
     exceed_cudnn_limit = false;
     Stream<xpu> *s = ctx.get_stream<xpu>();

-    Tensor<xpu, 3, real_t> data =
-        in_data[ctc_loss::kData].get<xpu, 3, real_t>(s);
-    Tensor<xpu, 2, real_t> labels =
-        in_data[ctc_loss::kLabel].get<xpu, 2, real_t>(s);
-
-    Tensor<xpu, 1, real_t> costs =
-        out_data[ctc_loss::kOut].get<xpu, 1, real_t>(s);
-    Tensor<xpu, 3, real_t> grad =
-        out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
-
-    int max_seq_len = data.size(0);
-    int batch_size = data.size(1);
-    int alphabet_size = data.size(2);
-
-    // data_lengths
-    std::vector<int> data_lengths(batch_size, max_seq_len);
-    if (param_.use_data_lengths) {
-      int kInputLength = 2;
-      IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
-    }
-
-    // label_lengths
-    std::vector<int> packed_labels;
-    std::vector<int> label_lengths(batch_size);
-
-    if (param_.use_label_lengths) {
-      int kLabelLength = 2+param_.use_data_lengths;
-      exceed_cudnn_limit = PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, real_t>(s),
-                                             &packed_labels, &label_lengths);
-    } else {
-      exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0?0:-1,
-                                                     &packed_labels, &label_lengths);
-    }
-
-    // CUDNN is disabled due to lack of support for input lengths
-    /* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
-    /* if (!exceed_cudnn_limit) { */
-    /*   cudnn_forward(ctx, s, data, costs, grad, */
-    /*                 &data_lengths, &label_lengths, &packed_labels, */
-    /*                 max_seq_len, batch_size, alphabet_size, */
-    /*                 req[ctc_loss::kGrad] != mxnet::kNullOp); */
-    /* } else { */
-    /*   baidu_forward(ctx, s, data, costs, grad, */
-    /*                 &data_lengths, &label_lengths, &packed_labels, */
-    /*                 batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp); */
-    /* } */
-    /* #else */
-
-    baidu_forward(ctx, s, data, costs, grad,
-                  &data_lengths, &label_lengths, &packed_labels,
-                  batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);
-
-    if (param_.use_data_lengths) {
-      // baidu warp CTC implementation sometimes includes undefined gradients
-      // for data outside of length mask. Setting to 0 to make it consistent
-      // with CPU implementation.
-      int kInputLength = 2;
-      mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
-                             static_cast<real_t>(0));
-    }
+    MSHADOW_TYPE_SWITCH(in_data[ctc_loss::kLabel].type_flag_, DType, {
+      Tensor<xpu, 3, real_t> data =
+          in_data[ctc_loss::kData].get<xpu, 3, real_t>(s);
+      Tensor<xpu, 2, DType> labels =
+          in_data[ctc_loss::kLabel].get<xpu, 2, DType>(s);
+
+      Tensor<xpu, 1, real_t> costs =
+          out_data[ctc_loss::kOut].get<xpu, 1, real_t>(s);
+      Tensor<xpu, 3, real_t> grad =
+          out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
+
+      int max_seq_len = data.size(0);
+      int batch_size = data.size(1);
+      int alphabet_size = data.size(2);
+
+      // data_lengths
+      std::vector<int> data_lengths(batch_size, max_seq_len);
+      if (param_.use_data_lengths) {
+        int kInputLength = 2;
+        IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
+      }
+
+      // label_lengths
+      std::vector<int> packed_labels;
+      std::vector<int> label_lengths(batch_size);
+
+      if (param_.use_label_lengths) {
+        int kLabelLength = 2 + param_.use_data_lengths;
+        exceed_cudnn_limit =
+            PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, DType>(s),
+                              &packed_labels, &label_lengths);
+      } else {
+        exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0 ? 0 : -1,
+                                                       &packed_labels, &label_lengths);
+      }
+
+      // CUDNN is disabled due to lack of support for input lengths
+      /* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
+      /* if (!exceed_cudnn_limit) { */
+      /*   cudnn_forward(ctx, s, data, costs, grad, */
+      /*                 &data_lengths, &label_lengths, &packed_labels, */
+      /*                 max_seq_len, batch_size, alphabet_size, */
+      /*                 req[ctc_loss::kGrad] != mxnet::kNullOp); */
+      /* } else { */
+      /*   baidu_forward(ctx, s, data, costs, grad, */
+      /*                 &data_lengths, &label_lengths, &packed_labels, */
+      /*                 batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);*/
+      /* } */
+      /* #else */
+
+      baidu_forward(ctx, s, data, costs, grad,
+                    &data_lengths, &label_lengths, &packed_labels,
+                    batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);
+
+      if (param_.use_data_lengths) {
+        // baidu warp CTC implementation sometimes includes undefined gradients
+        // for data outside of length mask. Setting to 0 to make it consistent
+        // with CPU implementation.
+        int kInputLength = 2;
+        mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
+                               static_cast<real_t>(0));
+      }
+    });
   }

   virtual void Backward(const OpContext &ctx,
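
The practical effect of wrapping the Forward body in MSHADOW_TYPE_SWITCH on the label's type_flag_ is that labels are read with whatever dtype the user passes in, instead of being forced to real_t (float32). A minimal usage sketch of the resulting front-end behaviour, assuming a build that includes this change (array values are made up for illustration):

```python
import mxnet as mx
import numpy as np

# (max_seq_len, batch_size, alphabet_size) activations, as ctc_loss expects.
acts = mx.nd.array(np.random.uniform(size=(10, 2, 5)), dtype='float32')

# The same labels once as float32 (the only path before this PR)
# and once as int32 (the path this PR enables).
labels_f32 = mx.nd.array([[2, 3, 1], [2, 0, 0]], dtype='float32')
labels_i32 = mx.nd.array([[2, 3, 1], [2, 0, 0]], dtype='int32')

loss_f32 = mx.nd.contrib.ctc_loss(data=acts, label=labels_f32)
loss_i32 = mx.nd.contrib.ctc_loss(data=acts, label=labels_i32)

# Both label dtypes should produce the same per-sample loss.
np.testing.assert_allclose(loss_f32.asnumpy(), loss_i32.asnumpy(), rtol=1e-5)
```
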
@@ -434,17 +437,17 @@ class CTCLossOp : public Operator {
   }
 #endif  // __CUDACC__ && CUDNN

-  inline virtual void baidu_forward(const OpContext &ctx,
-                                    mshadow::Stream<xpu>* s,
-                                    mshadow::Tensor<xpu, 3, real_t> data,
-                                    mshadow::Tensor<xpu, 1, real_t> costs,
-                                    mshadow::Tensor<xpu, 3, real_t> grad,
-                                    std::vector<int>* data_lengths,
-                                    std::vector<int>* label_lengths,
-                                    std::vector<int>* packed_labels,
-                                    int batch_size,
-                                    int alphabet_size,
-                                    bool req_grad) {
+  inline void baidu_forward(const OpContext &ctx,
+                            mshadow::Stream<xpu>* s,
+                            mshadow::Tensor<xpu, 3, real_t> data,
+                            mshadow::Tensor<xpu, 1, real_t> costs,
+                            mshadow::Tensor<xpu, 3, real_t> grad,
+                            std::vector<int>* data_lengths,
+                            std::vector<int>* label_lengths,
+                            std::vector<int>* packed_labels,
+                            int batch_size,
+                            int alphabet_size,
+                            bool req_grad) {
     using namespace mshadow;
     // allocate temporary workspace
     size_t size_bytes;
@@ -461,7 +464,7 @@ class CTCLossOp : public Operator {
     compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(),
                      label_lengths->data(), data_lengths->data(),
                      workspace.dptr_, req_grad,
-                     param_.blank_label == 0?0:(alphabet_size-1));
+                     param_.blank_label == 0 ? 0 : (alphabet_size-1));
   }
 };  // class CTCLossOp
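
The ternary passed to compute_ctc_cost maps the operator's blank_label setting to a concrete blank index: class 0 or the last class (alphabet_size - 1). A hedged sketch of what this means at the Python level; the string values 'first' and 'last' are taken from the operator's parameter definition, which is not part of this diff:

```python
import mxnet as mx
import numpy as np

acts = mx.nd.array(np.random.uniform(size=(10, 2, 5)))
# Labels chosen to be valid under either convention (neither 0 nor alphabet_size - 1).
labels = mx.nd.array([[2, 3, 1], [2, 1, 1]], dtype='int32')

# blank_label='first': index 0 is the CTC blank (the blank_label == 0 branch above).
loss_first = mx.nd.contrib.ctc_loss(data=acts, label=labels, blank_label='first')

# blank_label='last': index alphabet_size - 1 is the blank (the other branch).
loss_last = mx.nd.contrib.ctc_loss(data=acts, label=labels, blank_label='last')
```
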

@@ -534,11 +537,24 @@ class CTCLossProp : public OperatorProperty {
     TShape oshape(1);
     oshape[0] = dshape[1];  // batch size
     out_shape->clear();
-    out_shape->push_back(oshape);
+    out_shape->push_back(oshape);  // forward output
     out_shape->push_back(dshape);  // grad output
     return true;
   }

+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    CHECK_LE(in_type->size(), this->ListArguments().size());
+    int dtype = (*in_type)[ctc_loss::kData];
+    CHECK_NE(dtype, -1) << "Input data must have specified type";
+
+    out_type->clear();
+    out_type->push_back(dtype);  // forward output
+    out_type->push_back(dtype);  // grad output
+    return true;
+  }
+
   OperatorProperty *Copy() const override {
     auto ptr = new CTCLossProp();
     ptr->param_ = param_;
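
The new InferType override propagates the data dtype to both outputs (the visible loss and the internal gradient output) while leaving the label dtype unconstrained, which is what lets an integer label flow through the graph. A small sketch of the symbol-level consequence, assuming a build with this change:

```python
import mxnet as mx
import numpy as np

data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
loss = mx.sym.contrib.ctc_loss(data=data, label=label)

# The label may carry an integer dtype; the output dtype follows the data input.
arg_types, out_types, aux_types = loss.infer_type(data=np.float32, label=np.int32)
print(out_types)  # expected to contain numpy.float32
```
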
1 change: 1 addition & 0 deletions tests/python/unittest/test_contrib_operator.py
@@ -244,6 +244,7 @@ def assert_match(inputs, x, y, threshold, is_ascend=False):
     assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False)
     assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True)

+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
19 changes: 19 additions & 0 deletions tests/python/unittest/test_operator.py
Expand Up @@ -4516,6 +4516,25 @@ def test_ctc_loss():
true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
check_ctc_loss(acts2, labels2, true_loss)

# Test 3: check use integer type as label
labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32)
true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
check_ctc_loss(acts2, labels3, true_loss)

@with_seed()
def test_ctc_loss_with_large_classes():
ctx = default_context()
m = 1024
n = 35
l = 10
num_classes = 6000
x = np.random.uniform(size=(n, m, num_classes))
y = np.random.randint(0, num_classes, size=(m, l))

data = mx.nd.array(x, ctx=ctx)
label = mx.nd.array(y, ctx=ctx)
loss = mx.nd.contrib.ctc_loss(data=data, label=label)
assert loss.asnumpy().shape[0] == m
Review comments on this test:

Member: Why are you testing for shape?

@apeforest (Contributor, Author), Sep 7, 2018: Just to test that the operator does not crash with a large number of classes.

Member: This test does not crash on the master branch without the change either.

@apeforest (Contributor, Author), Sep 7, 2018: That's true. This unit test is not meant to test my fix. It tests an earlier PR #11834, which did not include a unit test but was merged anyway.

Member: Thanks for that. Still, the batch size is unnecessarily large; why not make the test run faster? Also, there is still no test that covers the loss-of-precision problem that the integer label type solves, which is part of your fix. Would you mind adding that please?

@apeforest (Contributor, Author): Updated the batch size to 2.
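
For context on the precision point raised above: floating-point storage can represent integer class indices exactly only up to the width of the significand, so sufficiently large indices collide when labels are kept as floats. A small illustration of that general numerical fact (not a test taken from this PR):

```python
import numpy as np

# float32 has a 24-bit significand, so integers are exact only up to 2**24;
# beyond that, neighbouring class indices round to the same value.
print(np.float32(2**24) == np.float32(2**24 + 1))   # True:  indices collide
print(np.int32(2**24) == np.int32(2**24 + 1))       # False: int labels stay distinct

# float16 labels hit the same wall much earlier, at 2**11.
print(np.float16(2**11) == np.float16(2**11 + 1))   # True
```
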


@with_seed()
def test_ctc_loss_grad():