From f489810e0243aec05bc5107e94a9742cf55e1a1c Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Thu, 24 Aug 2017 12:02:41 -0700
Subject: [PATCH] contrib ctc interface changes, cudnn7 CTC, and gluon CTC
 (#7442)

* contrib ctc interface changes for compatibility

* cudnn ctc

* update per comments
---
 python/mxnet/gluon/loss.py            |  90 +++++++
 src/operator/contrib/ctc_loss-inl.h   | 331 ++++++++++++++++++++++----
 src/operator/contrib/ctc_loss.cc      |  12 +-
 src/operator/sequence_op_common.h     |  18 +-
 tests/python/gpu/test_operator_gpu.py |   1 +
 tests/python/unittest/test_loss.py    |  30 +++
 6 files changed, 430 insertions(+), 52 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 583910590868..bb45e8926e95 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -21,6 +21,8 @@
 from __future__ import absolute_import

 from .. import ndarray
+from ..contrib import symbol as symbol_contrib
+from ..contrib import ndarray as ndarray_contrib
 from ..base import numeric_types
 from .block import HybridBlock
@@ -295,3 +297,91 @@ def hybrid_forward(self, F, output, label, sample_weight=None):
         loss = label * (F.log(label+1e-8) - output)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return F.mean(loss, axis=self._batch_axis, exclude=True)
+
+
+class CTCLoss(Loss):
+    r"""Connectionist Temporal Classification Loss.
+
+    See `"Connectionist Temporal Classification: Labelling Unsegmented
+    Sequence Data with Recurrent Neural Networks"
+    <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ paper for more information.
+
+    Parameters
+    ----------
+    layout : str, default 'NTC'
+        Layout of the output sequence activation vector.
+    label_layout : str, default 'NT'
+        Layout of the labels.
+    padding_mask : int or None, default -1
+        This is the label value to be considered padding, which is used to derive the
+        actual lengths of labels. Only required when `label_lengths` is None.
+    weight : float or None
+        Global scalar weight for loss.
+    sample_weight : Symbol or None
+        Per sample weighting. Must be broadcastable to
+        the same shape as loss. For example, if loss has
+        shape (64, 10) and you want to weight each sample
+        in the batch, `sample_weight` should have shape (64, 1).
+        This should be used as the fifth argument when calling this loss.
+
+    Input shapes:
+        `data` is an activation tensor (i.e. the output before softmax).
+        Its shape depends on `layout`. For `layout='TNC'`, this
+        input has shape `(sequence_length, batch_size, alphabet_size)`.
+
+        `label` is the label index matrix.
+        Its shape depends on `label_layout`. For `label_layout='TN'`, this
+        input has shape `(label_sequence_length, batch_size)`.
+        When `label_lengths` is not specified, the first occurrence of `padding_mask`
+        in each sample marks the end of the label sequence of that sample.
+        For example, suppose there are two samples, with *label_sequence_length* = 4.
+        The two label sequences are [2, 1] and [3, 2, 2], and their actual lengths
+        are smaller than 4. Thus, given *padding_mask* = 0, the resulting `label`
+        tensor should be padded to be::
+
+            [[2, 1, 0, 0], [3, 2, 2, 0]]
+
+        `data_lengths` is optional and defaults to None.
+        When specified, it represents the actual lengths of data.
+        The shape should be (batch_size,).
+        If None, the data lengths are treated as being equal to the max sequence length.
+        This should be used as the third argument when calling this loss.
+
+        `label_lengths` is optional and defaults to None.
+        When specified, it represents the actual lengths of labels.
+        The shape should be (batch_size,).
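For reviewers, a minimal usage sketch of the Gluon loss added above (a sketch, assuming an MXNet build that includes this patch; shapes and padding values mirror the docstring example):

    import mxnet as mx
    from mxnet import gluon

    data = mx.nd.uniform(shape=(2, 20, 4))             # 'NTC': batch 2, 20 steps, alphabet 4
    label = mx.nd.array([[2, 1, 0, 0], [3, 2, 2, 0]])  # sequences [2,1] and [3,2,2], padded with 0

    loss_fn = gluon.loss.CTCLoss(layout='NTC', label_layout='NT', padding_mask=0)
    print(loss_fn(data, label).shape)                  # (2,): one loss value per sample

    # Explicit lengths (third and fourth call arguments) instead of padding:
    print(loss_fn(data, label, mx.nd.array([20, 20]), mx.nd.array([2, 3])))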
+        If None, the label lengths are derived from the first occurrence of
+        the value specified by `padding_mask`.
+        This should be used as the fourth argument when calling this loss.
+
+    Output shape:
+        The CTC loss output has the shape (batch_size,).
+    """
+    def __init__(self, layout='NTC', label_layout='NT', padding_mask=-1,
+                 weight=None, **kwargs):
+        assert layout in ['NTC', 'TNC'],\
+            "Only 'NTC' and 'TNC' layouts for output are supported. Got: %s"%layout
+        assert label_layout in ['NT', 'TN'],\
+            "Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
+        self._layout = layout
+        self._label_layout = label_layout
+        self._padding_mask = padding_mask
+        batch_axis = label_layout.find('N')
+        super(CTCLoss, self).__init__(weight, batch_axis, **kwargs)
+
+    def hybrid_forward(self, F, data, label,
+                       data_lengths=None, label_lengths=None, sample_weight=None):
+        if self._layout == 'NTC':
+            data = F.swapaxes(data, 0, 1)
+        if self._batch_axis == 1:
+            label = F.swapaxes(label, 0, 1)
+        if F is ndarray:
+            F_contrib = ndarray_contrib
+        else:
+            F_contrib = symbol_contrib
+        loss = F_contrib.CTCLoss(data, label,
+                                 use_data_lengths=data_lengths is not None,
+                                 use_label_lengths=label_lengths is not None,
+                                 data_lengths=data_lengths, label_lengths=label_lengths,
+                                 padding_mask=self._padding_mask)
+        return _apply_weighting(F, loss, self._weight, sample_weight)

diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index 0d0c0bf4cd09..13ce1f240afd 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -41,6 +41,11 @@
 #include "../sequence_op_common.h"
 #include "../mshadow_op.h"

+#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+#define CUDNN_LABEL_LENGTH_LIMIT 256
+#include "../nn/softmax-inl.h"
+#endif
+
 namespace mxnet {
 namespace op {

@@ -52,14 +57,14 @@ enum CTCLossOpForwardResource { kTempSpace };

 template <typename T>
 inline void get_workspace_size(std::vector<int> *label_lengths,
-                               std::vector<int> *input_lengths,
+                               std::vector<int> *data_lengths,
                                int alphabet_size, int minibatch, bool gpu,
                                size_t *size_bytes) {
   // This is the max of all S and T for all examples in the minibatch.
   int maxL = *std::max_element(label_lengths->data(),
                                label_lengths->data() + minibatch);
-  int maxT = *std::max_element(input_lengths->data(),
-                               input_lengths->data() + minibatch);
+  int maxT = *std::max_element(data_lengths->data(),
+                               data_lengths->data() + minibatch);

   const int S = 2 * maxL + 1;

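The next hunk reworks label packing: instead of hard-coding 0 as the padding value, the packing helper takes an explicit `padding_mask`, and a second helper packs by caller-supplied lengths. A rough NumPy sketch of the padding-based semantics (`pack_labels` is a hypothetical illustration, not part of the patch):

    import numpy as np

    def pack_labels(labels, padding_mask):
        """Mimics LabelTensorToPackedVector: concatenate labels, dropping padding."""
        packed, lengths = [], []
        for row in labels:                        # one label sequence per sample
            pad = np.where(row == padding_mask)[0]
            length = pad[0] if len(pad) else len(row)
            packed.extend(int(x) for x in row[:length])
            lengths.append(int(length))
        return packed, lengths

    print(pack_labels(np.array([[2, 1, 0, 0], [3, 2, 2, 0]]), 0))
    # ([2, 1, 3, 2, 2], [2, 3]): flat label stream plus per-sample lengths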
@@ -125,34 +130,109 @@ inline void get_workspace_size(std::vector<int> *label_lengths,
 }

 // Takes a tensor of labels, and interprets 0-elements at the end of the vector
-// as padding. The tensor is packed into a std::vector without padding
-// characters. The sequence lengths are also inferred from the padding chars
+// as padding. The tensor is packed into an std::vector without padding
+// characters. The label sequence lengths are also inferred from the padding chars.
+// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded.
 template <typename DType, typename xpu>
-inline void LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
+inline bool LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
+                                      int padding_mask,
                                       std::vector<int> *packed_labels,
                                       std::vector<int> *label_lengths) {
   int batch = labels.size(0);
   int max_num_labels = labels.size(1);
-  std::vector<int> cpu_labels(max_num_labels);
+  bool exceed_limit = false;
+
+  std::vector<int> cpu_labels(max_num_labels*batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
+  for (int b = 0; b < batch; ++b) {
+    auto start = cpu_labels.data()+b*max_num_labels;
+    auto res = std::find(start, start+max_num_labels, padding_mask);
+    int len = std::distance(start, res);
+#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+    exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
+#endif
+    std::copy(start, start + len,
+              std::back_inserter(*packed_labels));
+    label_lengths->at(b) = len;
+  }
+  return exceed_limit;
+}
+
+// Takes a tensor of labels, and a vector which specifies the actual length of each label.
+// The tensor is packed into an std::vector without padding characters.
+// The label length vector is copied into an std::vector.
+// When cudnn is enabled, the return value signifies whether the cudnn length limit is exceeded.
+template <typename DType, typename xpu>
+inline bool PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
+                              mshadow::Tensor<xpu, 1, DType> in_label_lengths,
+                              std::vector<int> *packed_labels,
+                              std::vector<int> *label_lengths) {
+  int batch = labels.size(0);
+  int max_num_labels = labels.size(1);
+  bool exceed_limit = false;
+
+  IndexTensorToVector(in_label_lengths, label_lengths);
+
+  std::vector<int> cpu_labels(max_num_labels*batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);

   for (int b = 0; b < batch; ++b) {
-    IndexTensorToVector(labels[b], &cpu_labels);
-    auto res = std::find(cpu_labels.begin(), cpu_labels.end(), 0);
-    int len = std::distance(cpu_labels.begin(), res);
-    std::copy(cpu_labels.begin(), cpu_labels.begin() + len,
+    auto start = cpu_labels.data()+b*max_num_labels;
+    int len = label_lengths->at(b);
+#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+    exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
+#endif
+    std::copy(start, start + len,
              std::back_inserter(*packed_labels));
-    label_lengths->emplace_back(len);
   }
+  return exceed_limit;
 }

 struct CTCLossParam : public dmlc::Parameter<CTCLossParam> {
-  DMLC_DECLARE_PARAMETER(CTCLossParam) {}
+  bool use_data_lengths;
+  bool use_label_lengths;
+  dmlc::optional<int> padding_mask;
+  DMLC_DECLARE_PARAMETER(CTCLossParam) {
+    DMLC_DECLARE_FIELD(use_data_lengths).set_default(false)
+      .describe("Whether the data lengths are decided by `data_lengths`. "
+                "If false, the lengths are equal to the max sequence length.");
+    DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
+      .describe("Whether the label lengths are decided by "
+                "`label_lengths`, or derived from `padding_mask`. "
+                "If false, the lengths are derived from the "
+                "first occurrence of the value of `padding_mask`.");
+    DMLC_DECLARE_FIELD(padding_mask).set_default(dmlc::optional<int>(0))
+      .describe("int or None. This is the label value to be considered padding. "
+                "Only required when `use_label_lengths` is false. "
+                "Labels before the first occurrence of `padding_mask` are included "
+                "in calculation.");
+  }
 };
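These parameters surface directly on the contrib operator. A minimal sketch of calling it with explicit lengths (values are illustrative only):

    import mxnet as mx

    # Operator-level layout: data is (sequence_length, batch_size, alphabet_size).
    data = mx.nd.uniform(shape=(20, 2, 4))
    label = mx.nd.array([[2, 1], [3, 2]])

    out = mx.nd.contrib.ctc_loss(data=data, label=label,
                                 use_data_lengths=True, use_label_lengths=True,
                                 data_lengths=mx.nd.array([20, 16]),
                                 label_lengths=mx.nd.array([2, 2]))
    print(out.shape)  # (2,): one loss value per sample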
" + "Labels before the first occurrence of `padding_mask` are included " + "in calculation."); + } }; template class CTCLossOp : public Operator { public: - explicit CTCLossOp(CTCLossParam p) { this->param_ = p; } + explicit CTCLossOp(CTCLossParam p) { + this->param_ = p; + exceed_cudnn_limit = false; +#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 + CUDNN_CALL(cudnnCreateCTCLossDescriptor(&ctc_desc_)); + CUDNN_CALL(cudnnSetCTCLossDescriptor(ctc_desc_, CUDNN_DATA_FLOAT)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&prob_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&grad_desc_)); +#endif + } + + ~CTCLossOp() { +#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 + CUDNN_CALL(cudnnDestroyCTCLossDescriptor(ctc_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(prob_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(grad_desc_)); +#endif + } virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, @@ -160,8 +240,9 @@ class CTCLossOp : public Operator { const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 2U); + CHECK_EQ(in_data.size(), 2U+param_.use_data_lengths+param_.use_label_lengths); CHECK_EQ(out_data.size(), 2U); + exceed_cudnn_limit = false; Stream *s = ctx.get_stream(); Tensor data = @@ -178,27 +259,41 @@ class CTCLossOp : public Operator { int batch_size = data.size(1); int alphabet_size = data.size(2); + // data_lengths + std::vector data_lengths(batch_size, max_seq_len); + if (param_.use_data_lengths) { + int kInputLength = 2; + IndexTensorToVector(in_data[kInputLength].get(s), &data_lengths); + } + // label_lengths std::vector packed_labels; - std::vector label_lengths; - LabelTensorToPackedVector(labels, &packed_labels, &label_lengths); - - // allocate temporary workspace - std::vector input_lengths(batch_size, max_seq_len); - size_t size_bytes; - bool gpu = data.kDevCPU ? 
-    get_workspace_size<real_t>(&label_lengths, &input_lengths, alphabet_size,
-                               batch_size, gpu, &size_bytes);
-
-    // round-up so there are enough elems in memory
-    int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t);
-    Tensor<xpu, 1, real_t> workspace =
-        ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
-            Shape1(num_tmp_elems), s);
-
-    compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels.data(),
-                     label_lengths.data(), input_lengths.data(),
-                     workspace.dptr_, ctx.is_train);
+    std::vector<int> label_lengths(batch_size);
+
+    if (param_.use_label_lengths) {
+      int kLabelLength = 2+param_.use_data_lengths;
+      exceed_cudnn_limit = PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, real_t>(s),
+                                             &packed_labels, &label_lengths);
+    } else {
+      exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.padding_mask.value(),
+                                                     &packed_labels, &label_lengths);
+    }
+
+#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+    if (!param_.use_data_lengths && !exceed_cudnn_limit) {
+      cudnn_forward(ctx, s, data, costs, grad,
+                    &data_lengths, &label_lengths, &packed_labels,
+                    max_seq_len, batch_size, alphabet_size);
+    } else {
+      baidu_forward(ctx, s, data, costs, grad,
+                    &data_lengths, &label_lengths, &packed_labels,
+                    batch_size, alphabet_size);
+    }
+#else
+    baidu_forward(ctx, s, data, costs, grad,
+                  &data_lengths, &label_lengths, &packed_labels,
+                  batch_size, alphabet_size);
+#endif  // __CUDACC__ && CUDNN
   }

   virtual void Backward(const OpContext &ctx,
@@ -221,12 +316,143 @@ class CTCLossOp : public Operator {
     Tensor<xpu, 3, real_t> data_grad_computed =
         out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);

-    Assign(data_grad, req[ctc_loss::kData],
-           broadcast<1>(output_grad, data_grad.shape_) * data_grad_computed);
+#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+    if (!param_.use_data_lengths && !exceed_cudnn_limit) {
+      cudnn_backward_extra(s, data_grad, output_grad, data_grad_computed);
+    } else {
+      baidu_backward_extra(req, data_grad, output_grad, data_grad_computed);
+    }
+#else
+    baidu_backward_extra(req, data_grad, output_grad, data_grad_computed);
+#endif
   }

  private:
   CTCLossParam param_;
+  bool exceed_cudnn_limit;
+
+#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+  cudnnDataType_t dtype_;
+  cudnnCTCLossDescriptor_t ctc_desc_;
+  cudnnTensorDescriptor_t prob_desc_, grad_desc_;
+
+  inline virtual void cudnn_forward(const OpContext &ctx,
+                                    mshadow::Stream<xpu>* s,
+                                    mshadow::Tensor<xpu, 3, real_t> data,
+                                    mshadow::Tensor<xpu, 1, real_t> costs,
+                                    mshadow::Tensor<xpu, 3, real_t> grad,
+                                    std::vector<int>* data_lengths,
+                                    std::vector<int>* label_lengths,
+                                    std::vector<int>* packed_labels,
+                                    int max_seq_len,
+                                    int batch_size,
+                                    int alphabet_size) {
+    using namespace mshadow;
+
+    // call cudnn to calculate ctc loss
+    dtype_ = CUDNN_DATA_FLOAT;
+    int dims[3], strides[3];
+    size_t workspace_bytes;
+    int workspace_size;
+    dims[0] = max_seq_len;
+    dims[1] = batch_size;
+    dims[2] = alphabet_size;
+    strides[0] = batch_size*alphabet_size;
+    strides[1] = alphabet_size;
+    strides[2] = 1;
+    cudnnCTCLossAlgo_t ctc_algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(prob_desc_,
+                                          dtype_,
+                                          3,
+                                          dims,
+                                          strides));
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(grad_desc_,
+                                          dtype_,
+                                          3,
+                                          dims,
+                                          strides));
+    CUDNN_CALL(cudnnGetCTCLossWorkspaceSize(s->dnn_handle_,
+                                            prob_desc_,
+                                            grad_desc_,
+                                            packed_labels->data(),
+                                            label_lengths->data(),
+                                            data_lengths->data(),
+                                            ctc_algo,
+                                            ctc_desc_,
+                                            &workspace_bytes));
+    workspace_size = workspace_bytes/sizeof(real_t);
+
+    Tensor<xpu, 1, real_t> temp_space =
+        ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
+            mshadow::Shape1(workspace_size+data.shape_.FlatTo1D()[0]), s);
+
+    Tensor<xpu, 1, real_t> work_space(temp_space.dptr_,
+                                      mshadow::Shape1(workspace_size), s);
+    Tensor<xpu, 3, real_t> prob(temp_space.dptr_+workspace_size,
+                                data.shape_, s);
+
+    // since the input is activation before softmax and cudnn ctc takes softmax
+    // apply softmax to inputs first.
+    mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2);
+
+    CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
+                            prob_desc_,
+                            prob.dptr_,
+                            packed_labels->data(),
+                            label_lengths->data(),
+                            data_lengths->data(),
+                            costs.dptr_,
+                            grad_desc_,
+                            grad.dptr_,
+                            ctc_algo,
+                            ctc_desc_,
+                            work_space.dptr_,
+                            workspace_bytes));
+  }
+  inline virtual void cudnn_backward_extra(mshadow::Stream<xpu>* s,
+                                           mshadow::Tensor<xpu, 3, real_t> data_grad,
+                                           mshadow::Tensor<xpu, 1, real_t> output_grad,
+                                           mshadow::Tensor<xpu, 3, real_t> data_grad_computed) {
+    mxnet_op::SoftmaxGrad(s,
+        output_grad.dptr_, data_grad_computed.dptr_, data_grad.dptr_, data_grad.shape_, 2);
+  }
+#endif  // __CUDACC__ && CUDNN
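As the comment in `cudnn_forward` notes, cuDNN's CTC implementation consumes probabilities, while the operator (like warp-ctc, which softmaxes internally) takes raw pre-softmax activations; the cuDNN path therefore applies softmax first. From the user's side the contract is unchanged, as this sketch illustrates (uniform activations softmax to uniform probabilities):

    import mxnet as mx

    acts = mx.nd.ones((20, 2, 4))      # raw activations, (T, N, C)
    probs = mx.nd.softmax(acts)        # over the trailing (alphabet) axis
    print(probs[0, 0].asnumpy())       # [0.25 0.25 0.25 0.25]

    # Both the cuDNN and warp-ctc paths take `acts`, never `probs`:
    loss = mx.nd.contrib.ctc_loss(data=acts, label=mx.nd.array([[2, 1], [3, 2]]))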
+
+  inline virtual void baidu_forward(const OpContext &ctx,
+                                    mshadow::Stream<xpu>* s,
+                                    mshadow::Tensor<xpu, 3, real_t> data,
+                                    mshadow::Tensor<xpu, 1, real_t> costs,
+                                    mshadow::Tensor<xpu, 3, real_t> grad,
+                                    std::vector<int>* data_lengths,
+                                    std::vector<int>* label_lengths,
+                                    std::vector<int>* packed_labels,
+                                    int batch_size,
+                                    int alphabet_size) {
+    using namespace mshadow;
+    // allocate temporary workspace
+    size_t size_bytes;
+    bool gpu = data.kDevCPU ? false : true;
+    get_workspace_size<real_t>(label_lengths, data_lengths, alphabet_size,
+                               batch_size, gpu, &size_bytes);
+
+    // round-up so there are enough elems in memory
+    int num_tmp_elems = (size_bytes + sizeof(real_t) - 1) / sizeof(real_t);
+    Tensor<xpu, 1, real_t> workspace =
+        ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
+            Shape1(num_tmp_elems), s);
+
+    compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(),
+                     label_lengths->data(), data_lengths->data(),
+                     workspace.dptr_, ctx.is_train);
+  }
+
+  inline virtual void baidu_backward_extra(const std::vector<OpReqType> &req,
+                                           mshadow::Tensor<xpu, 3, real_t> data_grad,
+                                           mshadow::Tensor<xpu, 1, real_t> output_grad,
+                                           mshadow::Tensor<xpu, 3, real_t> data_grad_computed) {
+    Assign(data_grad, req[ctc_loss::kData],
+           mshadow::expr::broadcast<1>(output_grad, data_grad.shape_) * data_grad_computed);
+  }
 };  // class CTCLossOp

 template <typename xpu>
@@ -240,15 +466,22 @@ class CTCLossProp : public OperatorProperty {
   int NumOutputs() const override { return 2; }

   std::vector<std::string> ListArguments() const override {
-    return {"data", "label"};
+    if (param_.use_data_lengths && param_.use_label_lengths) {
+      return {"data", "label", "data_lengths", "label_lengths"};
+    } else if (param_.use_data_lengths) {
+      return {"data", "label", "data_lengths"};
+    } else if (param_.use_label_lengths) {
+      return {"data", "label", "label_lengths"};
+    } else {
+      return {"data", "label"};
+    }
   }

   std::vector<std::string> ListOutputs() const override {
     return {"output", "grad"};
   }

-  void Init(
-      const std::vector<std::pair<std::string, std::string>> &kwargs) override {
+  void Init(const std::vector<std::pair<std::string, std::string>> &kwargs) override {
     param_.Init(kwargs);
   }

@@ -259,7 +492,9 @@ class CTCLossProp : public OperatorProperty {
   bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
     using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 2U) << "Expect two inputs to the symbol.";
+    index_t expected_inputs = 2+param_.use_data_lengths+param_.use_label_lengths;
+    CHECK_EQ(in_shape->size(), expected_inputs)
+        << "Expect " << expected_inputs << " inputs to the symbol.";

     const TShape &dshape = (*in_shape)[ctc_loss::kData];
     const TShape &lshape = (*in_shape)[ctc_loss::kLabel];
@@ -267,10 +502,24 @@
     CHECK_EQ(lshape.ndim(), 2U) << "The labels array must be of rank 2.";
     CHECK_EQ(dshape[1], lshape[0])
         << "The batch size for the labels and data arrays must be the same.";
+    if (param_.use_data_lengths) {
+      int kInputLength = 2;
+      const TShape &dlshape = (*in_shape)[kInputLength];
+      CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector.";
+      CHECK_EQ(dlshape[0], dshape[1])
+          << "The batch size for the data and data lengths must be the same.";
+    }
+    if (param_.use_label_lengths) {
+      int kLabelLength = 2+param_.use_data_lengths;
+      const TShape &llshape = (*in_shape)[kLabelLength];
+      CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector.";
+      CHECK_EQ(llshape[0], lshape[0])
+          << "The batch size for the labels and label lengths must be the same.";
+    }
     CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed "
                                       "the maximum sequence length of the "
-                                      "input.";
+                                      "data.";

     TShape oshape(1);
     oshape[0] = dshape[1];  // batch size
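The shape checks above can be exercised from the symbol API. A small sketch (shapes illustrative; this assumes the auxiliary `grad` output is hidden from the symbol's visible outputs, so only the per-sample loss shape is reported):

    import mxnet as mx

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    sym = mx.sym.contrib.ctc_loss(data=data, label=label)

    # data: (max_seq_len, batch, alphabet); label: (batch, max_label_len)
    _, out_shapes, _ = sym.infer_shape(data=(20, 2, 4), label=(2, 4))
    print(out_shapes)  # expected: [(2,)], the batch-sized loss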
diff --git a/src/operator/contrib/ctc_loss.cc b/src/operator/contrib/ctc_loss.cc
index 3727cee10b1c..d544a1fdec04 100644
--- a/src/operator/contrib/ctc_loss.cc
+++ b/src/operator/contrib/ctc_loss.cc
@@ -31,7 +31,7 @@ namespace mshadow {
 template <typename DType>
 ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
                              DType *costs, DType *grads, int *labels,
-                             int *label_lengths, int *input_lengths,
+                             int *label_lengths, int *data_lengths,
                              void *workspace, int train) {
   int minibatch = static_cast<int>(activations.size(1));
   int alphabet_size = static_cast<int>(activations.size(2));
@@ -39,10 +39,10 @@ ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
   mxnet_warpctc::CpuCTC<DType> ctc(alphabet_size, minibatch, workspace, blank_label);
   if (train)
     return ctc.cost_and_grad(activations.dptr_, grads, costs, labels,
-                             label_lengths, input_lengths);
+                             label_lengths, data_lengths);
   else
     return ctc.score_forward(activations.dptr_, costs, labels, label_lengths,
-                             input_lengths);
+                             data_lengths);
 }
 }  // namespace mshadow

@@ -100,6 +100,12 @@ information.
 .add_argument("data", "NDArray-or-Symbol", "Input data to the ctc_loss op.")
 .add_argument("label", "NDArray-or-Symbol",
               "Ground-truth labels for the loss.")
+.add_argument("data_lengths", "NDArray-or-Symbol",
+              "Lengths of data for each of the samples. Only required "
+              "when use_data_lengths is true.")
+.add_argument("label_lengths", "NDArray-or-Symbol",
+              "Lengths of labels for each of the samples. Only required "
+              "when use_label_lengths is true.")
 .add_arguments(CTCLossParam::__FIELDS__());

 NNVM_REGISTER_OP(_contrib_CTCLoss).add_alias("_contrib_ctc_loss");
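With the optional length inputs registered above, the symbol's argument list follows the `use_*_lengths` flags (see `CTCLossProp::ListArguments` earlier), and either registered name resolves to the same operator. A quick sketch:

    import mxnet as mx

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_lengths = mx.sym.Variable('label_lengths')

    sym = mx.sym.contrib.CTCLoss(data=data, label=label,
                                 label_lengths=label_lengths,
                                 use_label_lengths=True)
    print(sym.list_arguments())  # expected: ['data', 'label', 'label_lengths']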
Only required " + "when use_label_lengths is true.") .add_arguments(CTCLossParam::__FIELDS__()); NNVM_REGISTER_OP(_contrib_CTCLoss).add_alias("_contrib_ctc_loss"); diff --git a/src/operator/sequence_op_common.h b/src/operator/sequence_op_common.h index 9e5843161087..724e0e0da121 100644 --- a/src/operator/sequence_op_common.h +++ b/src/operator/sequence_op_common.h @@ -32,9 +32,10 @@ namespace mxnet { namespace op { -template -void IndexTensorToVector(mshadow::Tensor data, - std::vector *index_vec) { +template +typename std::enable_if::value>::type +IndexTensorToVector(mshadow::Tensor data, + std::vector *index_vec) { int max_seq_len = data.shape_.Size(); #if MXNET_USE_CUDA DType *temp_index = @@ -44,18 +45,19 @@ void IndexTensorToVector(mshadow::Tensor data, cudaMemcpyDeviceToHost, data.stream_->stream_); CHECK_EQ(cuda_status, cudaSuccess) << "cuda memcpy label error"; for (int i = 0; i < max_seq_len; ++i) { - (*index_vec)[i] = static_cast(temp_index[i]); + (*index_vec)[i] = static_cast(temp_index[i]); } free(temp_index); #endif } -template -void IndexTensorToVector(mshadow::Tensor data, - std::vector *index_vec) { +template +typename std::enable_if::value>::type +IndexTensorToVector(mshadow::Tensor data, + std::vector *index_vec) { int max_seq_len = data.shape_.Size(); DType *index_array = static_cast(data.dptr_); for (int i = 0; i < max_seq_len; ++i) - (*index_vec)[i] = static_cast(index_array[i]); + (*index_vec)[i] = static_cast(index_array[i]); } } // namespace op diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 35a20f935573..11d146cae840 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -29,6 +29,7 @@ from test_optimizer import * from test_random import * from test_gluon import * +from test_loss import * #from test_rnn import * from test_gluon_rnn import * from test_sparse_operator import test_cast_storage_ex, test_sparse_dot diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py index 714ea7562fdb..b864215ca1d1 100644 --- a/tests/python/unittest/test_loss.py +++ b/tests/python/unittest/test_loss.py @@ -165,6 +165,36 @@ def test_l1_loss(): assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.1 +def test_ctc_loss(): + loss = gluon.loss.CTCLoss(padding_mask=0) + l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,0,0],[3,2,2,0]])) + mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss(layout='TNC', padding_mask=0) + l = loss(mx.nd.ones((20,2,4)), mx.nd.array([[2,1,0,0],[3,2,2,0]])) + mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN', padding_mask=0) + l = loss(mx.nd.ones((20,2,4)), mx.nd.array([[2,1,0,0],[3,2,2,0]]).T) + mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss(padding_mask=-1) + l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,-1,-1],[3,2,2,-1]])) + mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss() + l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,2,2],[3,2,2,2]]), None, mx.nd.array([2,3])) + mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741])) + + loss = gluon.loss.CTCLoss() + l = loss(mx.nd.ones((2,25,4)), mx.nd.array([[2,1,-1,-1],[3,2,2,-1]]), mx.nd.array([20,20])) + mx.test_utils.assert_almost_equal(l.asnumpy(), 
 def test_sample_weight_loss():
     mx.random.seed(1234)
     np.random.seed(1234)