
Commit

contrib ctc interface changes for compatibility
szha committed Aug 14, 2017
1 parent 568b5a2 commit 245a789
Showing 5 changed files with 230 additions and 18 deletions.
101 changes: 101 additions & 0 deletions python/mxnet/gluon/loss.py
@@ -21,6 +21,8 @@
from __future__ import absolute_import

from .. import ndarray
from ..contrib import symbol as symbol_contrib
from ..contrib import ndarray as ndarray_contrib
from ..base import numeric_types
from .block import HybridBlock

@@ -295,3 +297,102 @@ def hybrid_forward(self, F, output, label, sample_weight=None):
loss = label * (F.log(label+1e-8) - output)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)


class CTCLoss(Loss):
r"""Connectionist Temporal Classification Loss.
See `"Connectionist Temporal Classification: Labelling Unsegmented
Sequence Data with Recurrent Neural Networks"
<http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ paper for more information.
The prediction output should be an activation vector without softmax, with shape
according to the output_layout:
**TNC**: *(sequence_length, batch_size, alphabet_size + 1)*
**NTC**: *(batch_size, sequence_length, alphabet_size + 1)*
The loss output has the shape:
**loss**: *(batch_size,)*.
``label`` is a tensor of integers between 1 and *alphabet_size*, with shape according
to the label_layout:
**NT**: *(batch_size, label_sequence_length)*
**TN**: *(label_sequence_length, batch_size)*
If a sequence of labels is shorter than *label_sequence_length*, use the special
padding value given by `padding_mask` (0 in this example) at the end of the
sequence to conform it to the correct length. For example, if
*label_sequence_length* = 4, and one has two sequences of labels [2, 1] and
[3, 2, 2], the resulting ``label`` tensor should be padded to be::
    [[2, 1, 0, 0], [3, 2, 2, 0]]
Parameters
----------
output_layout : str, default 'NTC'
Layout of the output sequence activation vector.
label_layout : str, default 'NT'
Layout of the labels.
use_input_lengths : bool, default False
Whether to use `input_lengths` to decide lengths of inputs.
If false, the input lengths are treated as being equal to the max sequence length.
use_label_lengths : bool, default False
Whether to use `label_lengths` to decide lengths of labels.
If false, the label lengths are derived from the first occurrence of
the value specified by `padding_mask`.
padding_mask : int or None, default -1
This is the label value to be considered padding, which is used to derive the actual
lengths of labels. Only required when `use_label_lengths` is false.
weight : float or None
Global scalar weight for loss.
input_lengths : NDArray or None
Actual lengths of inputs. Only required when `use_input_lengths` is true.
This should be used as the third argument when calling this loss.
The shape should be (N,).
label_lengths : NDArray or None
Lengths of labels. Only required when `use_label_lengths` is true.
This should be used as the fourth argument when calling this loss.
The shape should be (N,).
sample_weight : Symbol or None
Per sample weighting. Must be broadcastable to
the same shape as loss. For example, if loss has
shape (64, 10) and you want to weight each sample
in the batch, `sample_weight` should have shape (64, 1).
This should be used as the fifth argument when calling this loss.
"""
def __init__(self, output_layout='NTC', label_layout='NT',
use_input_lengths=False, use_label_lengths=False, padding_mask=-1,
weight=None, **kwargs):
assert output_layout in ['NTC', 'TNC'],\
"Only 'NTC' and 'TNC' layouts for output are supported. Got: %s"%output_layout
assert label_layout in ['NT', 'TN'],\
"Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
self._output_layout = output_layout
self._label_layout = label_layout
self._use_input_lengths = use_input_lengths
self._use_label_lengths = use_label_lengths
self._padding_mask = padding_mask
batch_axis = label_layout.find('N')
super(CTCLoss, self).__init__(weight, batch_axis, **kwargs)

def hybrid_forward(self, F, output, label,
input_lengths=None, label_lengths=None, sample_weight=None):
assert not self._use_input_lengths or input_lengths is not None, \
"Must specify input_lengths."
assert not self._use_label_lengths or label_lengths is not None, \
"Must specify label_lengths."
if self._output_layout == 'NTC':
output = F.swapaxes(output, 0, 1)
if self._batch_axis == 1:
label = F.swapaxes(label, 0, 1)
if F is ndarray:
F_contrib = ndarray_contrib
else:
F_contrib = symbol_contrib
loss = F_contrib.CTCLoss(output, label,
use_input_lengths=self._use_input_lengths,
use_label_lengths=self._use_label_lengths,
input_lengths=input_lengths, label_lengths=label_lengths,
padding_mask=self._padding_mask)
return _apply_weighting(F, loss, self._weight, sample_weight)
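For reference, a minimal usage sketch of the new Gluon interface follows. It mirrors the unit tests added at the end of this commit; shapes, values, and the choice of padding are illustrative only.

import mxnet as mx
from mxnet import gluon

# Two samples, 20 time steps, alphabet of size 3 plus one extra slot => 4 channels.
pred = mx.nd.ones((2, 20, 4))                        # NTC layout (the default)
label = mx.nd.array([[2, 1, 0, 0], [3, 2, 2, 0]])    # 0 marks label padding here

# Label lengths derived from the padding value.
ctc = gluon.loss.CTCLoss(padding_mask=0)
print(ctc(pred, label))                              # shape (2,): one loss value per sample

# Explicit per-sample label lengths instead of a padding value;
# input_lengths stays None because use_input_lengths is left False.
ctc = gluon.loss.CTCLoss(use_label_lengths=True)
label2 = mx.nd.array([[2, 1, 2, 2], [3, 2, 2, 2]])
print(ctc(pred, label2, None, mx.nd.array([2, 3])))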
93 changes: 83 additions & 10 deletions src/operator/contrib/ctc_loss-inl.h
@@ -129,24 +129,60 @@ inline void get_workspace_size(std::vector<int> *label_lengths,
// characters. The sequence lengths are also inferred from the padding chars
template <typename DType, typename xpu>
inline void LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
int padding_mask,
std::vector<int> *packed_labels,
std::vector<int> *label_lengths) {
int batch = labels.size(0);
int max_num_labels = labels.size(1);
std::vector<index_t> cpu_labels(max_num_labels);
std::vector<int> cpu_labels(max_num_labels);

for (int b = 0; b < batch; ++b) {
IndexTensorToVector(labels[b], &cpu_labels);
auto res = std::find(cpu_labels.begin(), cpu_labels.end(), 0);
auto res = std::find(cpu_labels.begin(), cpu_labels.end(), padding_mask);
int len = std::distance(cpu_labels.begin(), res);
std::copy(cpu_labels.begin(), cpu_labels.begin() + len,
std::back_inserter(*packed_labels));
label_lengths->emplace_back(len);
label_lengths->at(b) = len;
}
}

template <typename DType, typename xpu>
inline void PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
mshadow::Tensor<xpu, 1, DType> in_label_lengths,
std::vector<int> *packed_labels,
std::vector<int> *label_lengths) {
int batch = labels.size(0);
int max_num_labels = labels.size(1);
std::vector<int> cpu_labels(max_num_labels);
IndexTensorToVector(in_label_lengths, label_lengths);

for (int b = 0; b < batch; ++b) {
IndexTensorToVector(labels[b], &cpu_labels);
int len = label_lengths->at(b);
std::copy(cpu_labels.begin(), cpu_labels.begin() + len,
std::back_inserter(*packed_labels));
}
}

struct CTCLossParam : public dmlc::Parameter<CTCLossParam> {
DMLC_DECLARE_PARAMETER(CTCLossParam) {}
bool use_input_lengths;
bool use_label_lengths;
dmlc::optional<int> padding_mask;
DMLC_DECLARE_PARAMETER(CTCLossParam) {
DMLC_DECLARE_FIELD(use_input_lengths).set_default(false)
.describe("Whether the input lenghts are decided by `input_lengths`. "
"If false, the lengths are equal to the max sequence length.");
DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
.describe("Whether the label lenghts are decided by "
"`label_lengths`, or derived from `padding_mask`. "
"If false, the lengths are derived from the "
"first occurrence of the value of `padding_mask`.");
DMLC_DECLARE_FIELD(padding_mask).set_default(dmlc::optional<int>(0))
.describe("int or None. This is the label value to be considered padding. "
"Only required when `use_label_lengths` is false. "
"Labels before the first occurrence of `padding_mask` are included "
"in calculation.");
}
};

template <typename xpu>
@@ -160,7 +196,7 @@ class CTCLossOp : public Operator {
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(in_data.size(), 2U);
CHECK_EQ(in_data.size(), 2U+param_.use_input_lengths+param_.use_label_lengths);
CHECK_EQ(out_data.size(), 2U);
Stream<xpu> *s = ctx.get_stream<xpu>();

@@ -178,13 +214,26 @@ class CTCLossOp : public Operator {
int batch_size = data.size(1);
int alphabet_size = data.size(2);

// input_lengths
std::vector<int> input_lengths(batch_size, max_seq_len);
if (param_.use_input_lengths) {
int kInputLength = 2;
IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &input_lengths);
}

// label_lengths
std::vector<int> packed_labels;
std::vector<int> label_lengths;
LabelTensorToPackedVector(labels, &packed_labels, &label_lengths);
std::vector<int> label_lengths(batch_size);
if (param_.use_label_lengths) {
int kLabelLength = 2+param_.use_input_lengths;
PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, real_t>(s),
&packed_labels, &label_lengths);
} else {
LabelTensorToPackedVector(labels, param_.padding_mask.value(),
&packed_labels, &label_lengths);
}

// allocate temporary workspace
std::vector<int> input_lengths(batch_size, max_seq_len);
size_t size_bytes;
bool gpu = data.kDevCPU ? false : true;
get_workspace_size<real_t>(&label_lengths, &input_lengths, alphabet_size,
@@ -240,7 +289,15 @@ class CTCLossProp : public OperatorProperty {
int NumOutputs() const override { return 2; }

std::vector<std::string> ListArguments() const override {
return {"data", "label"};
if (param_.use_input_lengths && param_.use_label_lengths) {
return {"data", "label", "input_lengths", "label_lengths"};
} else if (param_.use_input_lengths) {
return {"data", "label", "input_lengths"};
} else if (param_.use_label_lengths) {
return {"data", "label", "label_lengths"};
} else {
return {"data", "label"};
}
}

std::vector<std::string> ListOutputs() const override {
@@ -259,14 +316,30 @@ class CTCLossProp : public OperatorProperty {
bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 2U) << "Expect two inputs to the symbol.";
index_t expected_inputs = 2+param_.use_input_lengths+param_.use_label_lengths;
CHECK_EQ(in_shape->size(), expected_inputs)
<< "Expect " << expected_inputs << " inputs to the symbol.";

const TShape &dshape = (*in_shape)[ctc_loss::kData];
const TShape &lshape = (*in_shape)[ctc_loss::kLabel];
CHECK_EQ(dshape.ndim(), 3U) << "The data array must be of rank 3.";
CHECK_EQ(lshape.ndim(), 2U) << "The labels array must be of rank 2.";
CHECK_EQ(dshape[1], lshape[0])
<< "The batch size for the labels and data arrays must be the same.";
if (param_.use_input_lengths) {
int kInputLength = 2;
const TShape &dlshape = (*in_shape)[kInputLength];
CHECK_EQ(dlshape.ndim(), 1U) << "Input length array must be a vector.";
CHECK_EQ(dlshape[0], dshape[1])
<< "The batch size for the inputs and input lengths must be the same.";
}
if (param_.use_label_lengths) {
int kLabelLength = 2+param_.use_input_lengths;
const TShape &llshape = (*in_shape)[kLabelLength];
CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector.";
CHECK_EQ(llshape[0], lshape[0])
<< "The batch size for the labels and label lengths must be the same.";
}

CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed "
"the maximum sequence length of the "
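To illustrate how the new flags change the operator's argument list (a sketch only, assuming the contrib operator is exposed to Python as mx.sym.contrib.CTCLoss, matching the import used in loss.py above):

import mxnet as mx

data = mx.sym.Variable('data')    # (max_seq_len, batch_size, alphabet_size + 1), TNC
label = mx.sym.Variable('label')  # (batch_size, max_label_length)

# Default flags: only two inputs; label lengths are derived from padding_mask.
out = mx.sym.contrib.CTCLoss(data=data, label=label)
print(out.list_arguments())       # expected: ['data', 'label']

# With both flags enabled, the length vectors become additional inputs,
# and InferShape checks that each is a vector of size batch_size.
out = mx.sym.contrib.CTCLoss(data=data, label=label,
                             input_lengths=mx.sym.Variable('input_lengths'),
                             label_lengths=mx.sym.Variable('label_lengths'),
                             use_input_lengths=True, use_label_lengths=True)
print(out.list_arguments())       # expected: ['data', 'label', 'input_lengths', 'label_lengths']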
6 changes: 6 additions & 0 deletions src/operator/contrib/ctc_loss.cc
@@ -100,6 +100,12 @@ information.
.add_argument("data", "NDArray-or-Symbol", "Input data to the ctc_loss op.")
.add_argument("label", "NDArray-or-Symbol",
"Ground-truth labels for the loss.")
.add_argument("input_lengths", "NDArray-or-Symbol",
"Lengths of input for each of the samples. Only required "
"when use_input_lengths is true.")
.add_argument("label_lengths", "NDArray-or-Symbol",
"Lengths of labels for each of the samples. Only required "
"when use_label_lengths is true.")
.add_arguments(CTCLossParam::__FIELDS__());

NNVM_REGISTER_OP(_contrib_CTCLoss).add_alias("_contrib_ctc_loss");
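The same arguments can be passed imperatively. A hedged sketch, assuming the NDArray frontend mirrors the registration above (input values adapted from the unit tests below):

import mxnet as mx

pred = mx.nd.ones((20, 2, 4))                       # TNC: (seq_len, batch, alphabet_size + 1)
label = mx.nd.array([[2, 1, 3, 3], [3, 2, 2, 3]])   # (batch, max_label_length)

loss = mx.nd.contrib.CTCLoss(pred, label,
                             input_lengths=mx.nd.array([20, 20]),
                             label_lengths=mx.nd.array([2, 3]),
                             use_input_lengths=True, use_label_lengths=True)
print(loss)                                         # one loss value per sample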
18 changes: 10 additions & 8 deletions src/operator/sequence_op_common.h
@@ -32,9 +32,10 @@
namespace mxnet {
namespace op {

template <typename DType>
void IndexTensorToVector(mshadow::Tensor<gpu, 1, DType> data,
std::vector<index_t> *index_vec) {
template <typename DType, typename RType>
typename std::enable_if<std::is_integral<RType>::value>::type
IndexTensorToVector(mshadow::Tensor<gpu, 1, DType> data,
std::vector<RType> *index_vec) {
int max_seq_len = data.shape_.Size();
#if MXNET_USE_CUDA
DType *temp_index =
@@ -44,18 +45,19 @@ void IndexTensorToVector(mshadow::Tensor<gpu, 1, DType> data,
cudaMemcpyDeviceToHost, data.stream_->stream_);
CHECK_EQ(cuda_status, cudaSuccess) << "cuda memcpy label error";
for (int i = 0; i < max_seq_len; ++i) {
(*index_vec)[i] = static_cast<index_t>(temp_index[i]);
(*index_vec)[i] = static_cast<RType>(temp_index[i]);
}
free(temp_index);
#endif
}
template <typename DType>
void IndexTensorToVector(mshadow::Tensor<cpu, 1, DType> data,
std::vector<index_t> *index_vec) {
template <typename DType, typename RType>
typename std::enable_if<std::is_integral<RType>::value>::type
IndexTensorToVector(mshadow::Tensor<cpu, 1, DType> data,
std::vector<RType> *index_vec) {
int max_seq_len = data.shape_.Size();
DType *index_array = static_cast<DType *>(data.dptr_);
for (int i = 0; i < max_seq_len; ++i)
(*index_vec)[i] = static_cast<index_t>(index_array[i]);
(*index_vec)[i] = static_cast<RType>(index_array[i]);
}

} // namespace op
30 changes: 30 additions & 0 deletions tests/python/unittest/test_loss.py
@@ -165,6 +165,36 @@ def test_l1_loss():
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1


def test_ctc_loss():
loss = gluon.loss.CTCLoss(padding_mask=0)
l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,0,0],[3,2,2,0]]))
mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(output_layout='TNC', padding_mask=0)
l = loss(mx.nd.ones((20,2,4)), mx.nd.array([[2,1,0,0],[3,2,2,0]]))
mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(output_layout='TNC', label_layout='TN', padding_mask=0)
l = loss(mx.nd.ones((20,2,4)), mx.nd.array([[2,1,0,0],[3,2,2,0]]).T)
mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(padding_mask=-1)
l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,-1,-1],[3,2,2,-1]]))
mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(use_label_lengths=True)
l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,2,2],[3,2,2,2]]), None, mx.nd.array([2,3]))
mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(use_input_lengths=True)
l = loss(mx.nd.ones((2,25,4)), mx.nd.array([[2,1,-1,-1],[3,2,2,-1]]), mx.nd.array([20,20]))
mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(use_input_lengths=True, use_label_lengths=True)
l = loss(mx.nd.ones((2,25,4)), mx.nd.array([[2,1,3,3],[3,2,2,3]]), mx.nd.array([20,20]), mx.nd.array([2,3]))
mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))


def test_sample_weight_loss():
mx.random.seed(1234)
np.random.seed(1234)
