From 3cfbfeb1b8cadf4a30f8a621ce608af40a4edf2f Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Fri, 18 Aug 2017 00:46:11 -0700
Subject: [PATCH] update per comments

---
 python/mxnet/gluon/loss.py          |  79 ++++++++--------
 src/operator/contrib/ctc_loss-inl.h | 139 ++++++++++++++--------------
 src/operator/contrib/ctc_loss.cc    |  12 +--
 tests/python/unittest/test_loss.py  |   6 +-
 4 files changed, 120 insertions(+), 116 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index a2d25d04e99c..a40ad3074aa1 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -306,13 +306,7 @@ class CTCLoss(Loss):
     Sequence Data with Recurrent Neural Networks"
     <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ paper for more information.
 
-    The prediction output should be an activation vector without softmax, with shape
-    according to the output_layout:
-    **TNC**: *(sequence_length, batch_size, alphabet_size + 1)*
-    **NTC**: *(batch_size, sequence_length, alphabet_size + 1)*
-    The loss output has the shape:
-    **loss**: *(batch_size,)*.
 
     ``label`` is a tensor of integers between 1 and *alphabet_size*, with shape
     according to the label_layout:
@@ -330,69 +324,74 @@ class CTCLoss(Loss):
 
     Parameters
     ----------
-    output_layout : str, default 'NTC'
+    layout : str, default 'NTC'
         Layout of the output sequence activation vector.
     label_layout : str, default 'NT'
         Layout of the labels.
-    use_input_lengths : bool, default False
-        Whether to use `input_lengths` to decide lengths of inputs.
-        If false, the input lengths are treated as being equal to the max sequence length.
-    use_label_lengths : bool, default False
-        Whether to use `label_lengths` to decide lengths of labels.
-        If false, the label lengths are derived from the first occurrence of
-        the value specified by `padding_mask`.
    padding_mask : int or None, default -1
        This is the label value to be considered padding, which is used to derive the actual
-        lengths of labels. Only required when `use_label_lengths` is false.
+        lengths of labels. Only required when `label_lengths` is None.
     weight : float or None
         Global scalar weight for loss.
-    input_lengths : NDArray or None,
-        Actual lengths of inputs. Only required when `use_input_lengths` is true.
-        This should be used as the third argument when calling this loss.
-        The shape should be (N,).
-    label_lengths : NDArray or None,
-        Lengths of labels. Only required when `use_label_lengths` is true.
-        This should be used as the fourth argument when calling this loss.
-        The shape should be (N,).
     sample_weight : Symbol or None
         Per sample weighting. Must be broadcastable to
         the same shape as loss. For example, if loss has
         shape (64, 10) and you want to weight each sample
         in the batch, `sample_weight` should have shape (64, 1).
         This should be used as the fifth argument when calling this loss.
+
+    Input shapes:
+        `data` is an activation tensor without softmax.
+        Its shape depends on `layout`. For `layout='TNC'`, this
+        input has shape `(sequence_length, batch_size, alphabet_size)`
+
+        `label` is the label index matrix.
+        Its shape depends on `label_layout`. For `label_layout='TN'`, this
+        input has shape `(label_sequence_length, batch_size)`
+        When `label_lengths` is not specified, the first occurrence of `padding_mask`
+        in each sample marks the end of the label sequence of that sample.
+
+        `data_lengths` is optional and defaults to None.
+        When specified, it represents the actual lengths of data.
+        The shape should be (batch_size,).
+        If None, the data lengths are treated as being equal to the max sequence length.
+        This should be used as the third argument when calling this loss.
+
+        `label_lengths` is optional and defaults to None.
+        When specified, it represents the actual lengths of labels.
+        The shape should be (batch_size,).
+        If None, the label lengths are derived from the first occurrence of
+        the value specified by `padding_mask`.
+        This should be used as the fourth argument when calling this loss.
+
+    Output shape:
+        The CTC loss output has the shape (batch_size,).
     """
-    def __init__(self, output_layout='NTC', label_layout='NT',
-                 use_input_lengths=False, use_label_lengths=False, padding_mask=-1,
+    def __init__(self, layout='NTC', label_layout='NT', padding_mask=-1,
                  weight=None, **kwargs):
-        assert output_layout in ['NTC', 'TNC'],\
-               "Only 'NTC' and 'TNC' layouts for output are supported. Got: %s"%output_layout
+        assert layout in ['NTC', 'TNC'],\
+               "Only 'NTC' and 'TNC' layouts for output are supported. Got: %s"%layout
         assert label_layout in ['NT', 'TN'],\
               "Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
-        self._output_layout = output_layout
+        self._layout = layout
         self._label_layout = label_layout
-        self._use_input_lengths = use_input_lengths
-        self._use_label_lengths = use_label_lengths
         self._padding_mask = padding_mask
         batch_axis = label_layout.find('N')
         super(CTCLoss, self).__init__(weight, batch_axis, **kwargs)
 
-    def hybrid_forward(self, F, output, label,
+    def hybrid_forward(self, F, data, label,
                        input_lengths=None, label_lengths=None, sample_weight=None):
-        assert not self._use_input_lengths or input_lengths is not None, \
-               "Must specify input_lengths."
-        assert not self._use_label_lengths or label_lengths is not None, \
-               "Must specify label_lengths."
-        if self._output_layout == 'NTC':
-            output = F.swapaxes(output, 0, 1)
+        if self._layout == 'NTC':
+            data = F.swapaxes(data, 0, 1)
         if self._batch_axis == 1:
             label = F.swapaxes(label, 0, 1)
         if F is ndarray:
             F_contrib = ndarray_contrib
         else:
             F_contrib = symbol_contrib
-        loss = F_contrib.CTCLoss(output, label,
-                                 use_input_lengths=self._use_input_lengths,
-                                 use_label_lengths=self._use_label_lengths,
+        loss = F_contrib.CTCLoss(data, label,
+                                 use_input_lengths=input_lengths is not None,
+                                 use_label_lengths=label_lengths is not None,
                                  input_lengths=input_lengths, label_lengths=label_lengths,
                                  padding_mask=self._padding_mask)
         return _apply_weighting(F, loss, self._weight, sample_weight)
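
For reference (not part of the patch itself), a minimal sketch of how the revised
CTCLoss signature is meant to be called; shapes and values mirror the updated unit
tests in tests/python/unittest/test_loss.py further below:

    import mxnet as mx

    # defaults: layout='NTC', label_layout='NT', padding_mask=-1
    loss_fn = mx.gluon.loss.CTCLoss()

    # data: (batch_size, sequence_length, alphabet_size), pre-softmax activations
    data = mx.nd.ones((2, 20, 4))
    # labels padded with padding_mask (-1); lengths are derived from the padding
    label = mx.nd.array([[2, 1, -1, -1], [3, 2, 2, -1]])
    l = loss_fn(data, label)               # per-sample loss, shape (batch_size,) == (2,)

    # alternatively, pass explicit lengths as the 3rd/4th positional arguments
    l = loss_fn(mx.nd.ones((2, 25, 4)),
                mx.nd.array([[2, 1, 3, 3], [3, 2, 2, 3]]),
                mx.nd.array([20, 20]),     # data lengths
                mx.nd.array([2, 3]))       # label lengths
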
diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h
index 489f727712f9..4ae4d7e1298a 100644
--- a/src/operator/contrib/ctc_loss-inl.h
+++ b/src/operator/contrib/ctc_loss-inl.h
@@ -57,14 +57,14 @@ enum CTCLossOpForwardResource { kTempSpace };
 
 template <typename T>
 inline void get_workspace_size(std::vector<int> *label_lengths,
-                               std::vector<int> *input_lengths,
+                               std::vector<int> *data_lengths,
                                int alphabet_size, int minibatch, bool gpu,
                                size_t *size_bytes) {
   // This is the max of all S and T for all examples in the minibatch.
   int maxL = *std::max_element(label_lengths->data(),
                                label_lengths->data() + minibatch);
-  int maxT = *std::max_element(input_lengths->data(),
-                               input_lengths->data() + minibatch);
+  int maxT = *std::max_element(data_lengths->data(),
+                               data_lengths->data() + minibatch);
 
   const int S = 2 * maxL + 1;
@@ -140,17 +140,20 @@ inline bool LabelTensorToPackedVector(mshadow::Tensor<xpu, 2, DType> labels,
                                       std::vector<int> *label_lengths) {
   int batch = labels.size(0);
   int max_num_labels = labels.size(1);
-  std::vector<int> cpu_labels(max_num_labels);
   bool exceed_limit = false;
+
+  std::vector<int> cpu_labels(max_num_labels*batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
   for (int b = 0; b < batch; ++b) {
-    IndexTensorToVector(labels[b], &cpu_labels);
-    auto res = std::find(cpu_labels.begin(), cpu_labels.end(), padding_mask);
-    int len = std::distance(cpu_labels.begin(), res);
+    auto start = cpu_labels.data()+b*max_num_labels;
+    auto res = std::find(start, start+max_num_labels, padding_mask);
+    int len = std::distance(start, res);
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
     exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
 #endif
-    std::copy(cpu_labels.begin(), cpu_labels.begin() + len,
+    std::copy(start, start + len,
               std::back_inserter(*packed_labels));
     label_lengths->at(b) = len;
   }
@@ -168,29 +171,33 @@ inline bool PackLabelByLength(mshadow::Tensor<xpu, 2, DType> labels,
                               std::vector<int> *label_lengths) {
   int batch = labels.size(0);
   int max_num_labels = labels.size(1);
-  std::vector<int> cpu_labels(max_num_labels);
-  IndexTensorToVector(in_label_lengths, label_lengths);
   bool exceed_limit = false;
+
+  IndexTensorToVector(in_label_lengths, label_lengths);
+
+  std::vector<int> cpu_labels(max_num_labels*batch);
+  mshadow::Tensor<xpu, 1, DType> flat_labels = labels.FlatTo1D();
+  IndexTensorToVector(flat_labels, &cpu_labels);
+
   for (int b = 0; b < batch; ++b) {
-    IndexTensorToVector(labels[b], &cpu_labels);
+    auto start = cpu_labels.data()+b*max_num_labels;
     int len = label_lengths->at(b);
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
     exceed_limit = exceed_limit || len > CUDNN_LABEL_LENGTH_LIMIT;
 #endif
-    std::copy(cpu_labels.begin(), cpu_labels.begin() + len,
+    std::copy(start, start + len,
              std::back_inserter(*packed_labels));
   }
   return exceed_limit;
 }
 
 struct CTCLossParam : public dmlc::Parameter<CTCLossParam> {
-  bool use_input_lengths;
+  bool use_data_lengths;
   bool use_label_lengths;
   dmlc::optional<int> padding_mask;
   DMLC_DECLARE_PARAMETER(CTCLossParam) {
-    DMLC_DECLARE_FIELD(use_input_lengths).set_default(false)
-      .describe("Whether the input lenghts are decided by `input_lengths`. "
+    DMLC_DECLARE_FIELD(use_data_lengths).set_default(false)
+      .describe("Whether the data lengths are decided by `data_lengths`. "
                "If false, the lengths are equal to the max sequence length.");
     DMLC_DECLARE_FIELD(use_label_lengths).set_default(false)
       .describe("Whether the label lenghts are decided by "
" "If false, the lengths are equal to the max sequence length."); DMLC_DECLARE_FIELD(use_label_lengths).set_default(false) .describe("Whether the label lenghts are decided by " @@ -210,15 +217,17 @@ class CTCLossOp : public Operator { public: explicit CTCLossOp(CTCLossParam p) { this->param_ = p; -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 exceed_cudnn_limit = false; +#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 CUDNN_CALL(cudnnCreateCTCLossDescriptor(&ctc_desc_)); CUDNN_CALL(cudnnSetCTCLossDescriptor(ctc_desc_, CUDNN_DATA_FLOAT)); CUDNN_CALL(cudnnCreateTensorDescriptor(&prob_desc_)); CUDNN_CALL(cudnnCreateTensorDescriptor(&grad_desc_)); +#endif } ~CTCLossOp() { +#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 CUDNN_CALL(cudnnDestroyCTCLossDescriptor(ctc_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(prob_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(grad_desc_)); @@ -231,11 +240,9 @@ class CTCLossOp : public Operator { const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 2U+param_.use_input_lengths+param_.use_label_lengths); + CHECK_EQ(in_data.size(), 2U+param_.use_data_lengths+param_.use_label_lengths); CHECK_EQ(out_data.size(), 2U); -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 exceed_cudnn_limit = false; -#endif Stream *s = ctx.get_stream(); Tensor data = @@ -252,11 +259,11 @@ class CTCLossOp : public Operator { int batch_size = data.size(1); int alphabet_size = data.size(2); - // input_lengths - std::vector input_lengths(batch_size, max_seq_len); - if (param_.use_input_lengths) { + // data_lengths + std::vector data_lengths(batch_size, max_seq_len); + if (param_.use_data_lengths) { int kInputLength = 2; - IndexTensorToVector(in_data[kInputLength].get(s), &input_lengths); + IndexTensorToVector(in_data[kInputLength].get(s), &data_lengths); } // label_lengths @@ -264,33 +271,27 @@ class CTCLossOp : public Operator { std::vector label_lengths(batch_size); if (param_.use_label_lengths) { - int kLabelLength = 2+param_.use_input_lengths; -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - exceed_cudnn_limit = -#endif - PackLabelByLength(labels, in_data[kLabelLength].get(s), - &packed_labels, &label_lengths); + int kLabelLength = 2+param_.use_data_lengths; + exceed_cudnn_limit = PackLabelByLength(labels, in_data[kLabelLength].get(s), + &packed_labels, &label_lengths); } else { -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - exceed_cudnn_limit = -#endif - LabelTensorToPackedVector(labels, param_.padding_mask.value(), - &packed_labels, &label_lengths); + exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.padding_mask.value(), + &packed_labels, &label_lengths); } #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 - if (!param_.use_input_lengths && !exceed_cudnn_limit) { + if (!param_.use_data_lengths && !exceed_cudnn_limit) { cudnn_forward(ctx, s, data, costs, grad, - &input_lengths, &label_lengths, &packed_labels, + &data_lengths, &label_lengths, &packed_labels, max_seq_len, batch_size, alphabet_size); } else { baidu_forward(ctx, s, data, costs, grad, - &input_lengths, &label_lengths, &packed_labels, + &data_lengths, &label_lengths, &packed_labels, batch_size, alphabet_size); } #else baidu_forward(ctx, s, data, costs, grad, - &input_lengths, &label_lengths, &packed_labels, + &data_lengths, &label_lengths, &packed_labels, batch_size, alphabet_size); #endif // __CUDACC__ && 
 #endif  // __CUDACC__ && CUDNN
   }
@@ -316,7 +317,7 @@ class CTCLossOp : public Operator {
         out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
 
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-    if (!param_.use_input_lengths && !exceed_cudnn_limit) {
+    if (!param_.use_data_lengths && !exceed_cudnn_limit) {
       cudnn_backward_extra(s, data_grad, output_grad, data_grad_computed);
     } else {
       baidu_backward_extra(req, data_grad, output_grad, data_grad_computed);
@@ -328,9 +329,9 @@ class CTCLossOp : public Operator {
 
  private:
   CTCLossParam param_;
+  bool exceed_cudnn_limit;
 
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
-  bool exceed_cudnn_limit;
   cudnnDataType_t dtype_;
   cudnnCTCLossDescriptor_t ctc_desc_;
   cudnnTensorDescriptor_t prob_desc_, grad_desc_;
@@ -340,19 +341,13 @@ class CTCLossOp : public Operator {
                                     mshadow::Tensor<xpu, 3, real_t> data,
                                     mshadow::Tensor<xpu, 1, real_t> costs,
                                     mshadow::Tensor<xpu, 3, real_t> grad,
-                                    std::vector<int>* input_lengths,
+                                    std::vector<int>* data_lengths,
                                     std::vector<int>* label_lengths,
                                     std::vector<int>* packed_labels,
                                     int max_seq_len,
                                     int batch_size,
                                     int alphabet_size) {
     using namespace mshadow;
-    // since the input is activation before softmax and cudnn ctc takes softmax
-    // apply softmax to inputs first.
-    Tensor<xpu, 3, real_t> prob(data.shape_);
-    mshadow::AllocSpace(&prob);
-    prob.set_stream(s);
-    mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2);
 
     // call cudnn to calculate ctc loss
     dtype_ = CUDNN_DATA_FLOAT;
@@ -381,28 +376,38 @@ class CTCLossOp : public Operator {
                                           grad_desc_,
                                           packed_labels->data(),
                                           label_lengths->data(),
-                                          input_lengths->data(),
+                                          data_lengths->data(),
                                           ctc_algo,
                                           ctc_desc_,
                                           &workspace_bytes));
     workspace_size = workspace_bytes/sizeof(real_t);
-    Tensor<xpu, 1, real_t> temp_space =
-      ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
-          mshadow::Shape1(workspace_size), s);
+
+    Tensor<xpu, 1, real_t> temp_space =
+      ctx.requested[ctc_loss::kTempSpace].get_space_typed<xpu, 1, real_t>(
+          mshadow::Shape1(workspace_size+data.shape_.FlatTo1D()[0]), s);
+
+    Tensor<xpu, 1, real_t> work_space(temp_space.dptr_,
+                                      mshadow::Shape1(workspace_size), s);
+    Tensor<xpu, 3, real_t> prob(temp_space.dptr_+workspace_size,
+                                data.shape_, s);
+
+    // since the input is activation before softmax and cudnn ctc takes softmax
+    // apply softmax to inputs first.
+    mxnet_op::Softmax(s, data.dptr_, prob.dptr_, data.shape_, 2);
+
     CUDNN_CALL(cudnnCTCLoss(s->dnn_handle_,
                             prob_desc_,
                             prob.dptr_,
                             packed_labels->data(),
                             label_lengths->data(),
-                            input_lengths->data(),
+                            data_lengths->data(),
                             costs.dptr_,
                             grad_desc_,
                             grad.dptr_,
                             ctc_algo,
                             ctc_desc_,
-                            temp_space.dptr_,
+                            work_space.dptr_,
                             workspace_bytes));
-    mshadow::FreeSpace(&prob);
   }
   inline virtual void cudnn_backward_extra(mshadow::Stream<xpu>* s,
                                            mshadow::Tensor<xpu, 3, real_t> data_grad,
@@ -418,7 +423,7 @@ class CTCLossOp : public Operator {
                                     mshadow::Tensor<xpu, 3, real_t> data,
                                     mshadow::Tensor<xpu, 1, real_t> costs,
                                     mshadow::Tensor<xpu, 3, real_t> grad,
-                                    std::vector<int>* input_lengths,
+                                    std::vector<int>* data_lengths,
                                     std::vector<int>* label_lengths,
                                     std::vector<int>* packed_labels,
                                     int batch_size,
@@ -427,7 +432,7 @@ class CTCLossOp : public Operator {
     // allocate temporary workspace
     size_t size_bytes;
     bool gpu = data.kDevCPU ? false : true;
-    get_workspace_size<real_t>(label_lengths, input_lengths, alphabet_size,
+    get_workspace_size<real_t>(label_lengths, data_lengths, alphabet_size,
                                batch_size, gpu, &size_bytes);
 
     // round-up so there are enough elems in memory
@@ -437,7 +442,7 @@ class CTCLossOp : public Operator {
             Shape1(num_tmp_elems), s);
 
     compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(),
-                     label_lengths->data(), input_lengths->data(),
+                     label_lengths->data(), data_lengths->data(),
                      workspace.dptr_, ctx.is_train);
   }
@@ -461,10 +466,10 @@ class CTCLossProp : public OperatorProperty {
   int NumOutputs() const override { return 2; }
 
   std::vector<std::string> ListArguments() const override {
-    if (param_.use_input_lengths && param_.use_label_lengths) {
-      return {"data", "label", "input_lengths", "label_lengths"};
-    } else if (param_.use_input_lengths) {
-      return {"data", "label", "input_lengths"};
+    if (param_.use_data_lengths && param_.use_label_lengths) {
+      return {"data", "label", "data_lengths", "label_lengths"};
+    } else if (param_.use_data_lengths) {
+      return {"data", "label", "data_lengths"};
     } else if (param_.use_label_lengths) {
       return {"data", "label", "label_lengths"};
     } else {
@@ -487,7 +492,7 @@ class CTCLossProp : public OperatorProperty {
   bool InferShape(std::vector<TShape> *in_shape,
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
     using namespace mshadow;
-    index_t expected_inputs = 2+param_.use_input_lengths+param_.use_label_lengths;
+    index_t expected_inputs = 2+param_.use_data_lengths+param_.use_label_lengths;
     CHECK_EQ(in_shape->size(), expected_inputs)
         << "Expect " << expected_inputs << " inputs to the symbol.";
@@ -497,15 +502,15 @@ class CTCLossProp : public OperatorProperty {
     CHECK_EQ(lshape.ndim(), 2U) << "The labels array must be of rank 2.";
     CHECK_EQ(dshape[1], lshape[0])
         << "The batch size for the labels and data arrays must be the same.";
-    if (param_.use_input_lengths) {
+    if (param_.use_data_lengths) {
       int kInputLength = 2;
       const TShape &dlshape = (*in_shape)[kInputLength];
-      CHECK_EQ(dlshape.ndim(), 1U) << "Input length array must be a vector.";
+      CHECK_EQ(dlshape.ndim(), 1U) << "Data length array must be a vector.";
       CHECK_EQ(dlshape[0], dshape[1])
-          << "The batch size for the inputs and input lengths must be the same.";
+          << "The batch size for the data and data lengths must be the same.";
     }
     if (param_.use_label_lengths) {
-      int kLabelLength = 2+param_.use_input_lengths;
+      int kLabelLength = 2+param_.use_data_lengths;
       const TShape &llshape = (*in_shape)[kLabelLength];
       CHECK_EQ(llshape.ndim(), 1U) << "Label length array must be a vector.";
       CHECK_EQ(llshape[0], lshape[0])
@@ -514,7 +519,7 @@ class CTCLossProp : public OperatorProperty {
     CHECK_GE(dshape[0], lshape[1]) << "The max number of labels cannot exceed "
                                       "the maximum sequence length of the "
-                                      "input.";
+                                      "data.";
 
     TShape oshape(1);
     oshape[0] = dshape[1];  // batch size
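
The renamed operator arguments can also be exercised directly, bypassing the gluon
wrapper. This is a sketch only, assuming the op is exposed under the contrib namespace
as mx.nd.contrib.CTCLoss; note the raw operator expects data in TNC layout, unlike the
gluon wrapper's NTC default:

    import mxnet as mx

    data = mx.nd.ones((20, 2, 4))   # (sequence_length, batch_size, alphabet_size)
    label = mx.nd.array([[2, 1, 2, 2], [3, 2, 2, 3]])
    out = mx.nd.contrib.CTCLoss(data, label,
                                data_lengths=mx.nd.array([20, 20]),
                                label_lengths=mx.nd.array([2, 3]),
                                use_data_lengths=True,
                                use_label_lengths=True)
    # the visible output is the per-sample loss of shape (batch_size,)
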
diff --git a/src/operator/contrib/ctc_loss.cc b/src/operator/contrib/ctc_loss.cc
index 317915996a33..d544a1fdec04 100644
--- a/src/operator/contrib/ctc_loss.cc
+++ b/src/operator/contrib/ctc_loss.cc
@@ -31,7 +31,7 @@ namespace mshadow {
 template <typename DType>
 ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
                              DType *costs, DType *grads, int *labels,
-                             int *label_lengths, int *input_lengths,
+                             int *label_lengths, int *data_lengths,
                              void *workspace, int train) {
   int minibatch = static_cast<int>(activations.size(1));
   int alphabet_size = static_cast<int>(activations.size(2));
@@ -39,10 +39,10 @@ ctcStatus_t compute_ctc_cost(const Tensor<cpu, 3, DType> activations,
   mxnet_warpctc::CpuCTC<DType> ctc(alphabet_size, minibatch, workspace, blank_label);
   if (train)
     return ctc.cost_and_grad(activations.dptr_, grads, costs, labels,
-                             label_lengths, input_lengths);
+                             label_lengths, data_lengths);
   else
     return ctc.score_forward(activations.dptr_, costs, labels, label_lengths,
-                             input_lengths);
+                             data_lengths);
 }
 
 }  // namespace mshadow
@@ -100,9 +100,9 @@ information.
 .add_argument("data", "NDArray-or-Symbol", "Input data to the ctc_loss op.")
 .add_argument("label", "NDArray-or-Symbol", "Ground-truth labels for the loss.")
-.add_argument("input_lengths", "NDArray-or-Symbol",
-              "Lengths of input for each of the samples. Only required "
-              "when use_input_lengths is true.")
+.add_argument("data_lengths", "NDArray-or-Symbol",
+              "Lengths of data for each of the samples. Only required "
+              "when use_data_lengths is true.")
 .add_argument("label_lengths", "NDArray-or-Symbol",
               "Lengths of labels for each of the samples. Only required "
               "when use_label_lengths is true.")
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index ab31422bc71b..341750c67aa3 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -182,15 +182,15 @@ def test_ctc_loss():
     l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,-1,-1],[3,2,2,-1]]))
     mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))
 
-    loss = gluon.loss.CTCLoss(use_label_lengths=True)
+    loss = gluon.loss.CTCLoss()
     l = loss(mx.nd.ones((2,20,4)), mx.nd.array([[2,1,2,2],[3,2,2,2]]),
              None, mx.nd.array([2,3]))
     mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))
 
-    loss = gluon.loss.CTCLoss(use_input_lengths=True)
+    loss = gluon.loss.CTCLoss()
     l = loss(mx.nd.ones((2,25,4)), mx.nd.array([[2,1,-1,-1],[3,2,2,-1]]), mx.nd.array([20,20]))
     mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))
 
-    loss = gluon.loss.CTCLoss(use_input_lengths=True, use_label_lengths=True)
+    loss = gluon.loss.CTCLoss()
     l = loss(mx.nd.ones((2,25,4)), mx.nd.array([[2,1,3,3],[3,2,2,3]]), mx.nd.array([20,20]), mx.nd.array([2,3]))
     mx.test_utils.assert_almost_equal(l.asnumpy(), np.array([18.82820702, 16.50581741]))
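
One more sketch (again not part of the patch) showing the non-default layouts described
by the reworked docstring, with time-major data and label_layout='TN'; the label values
here are illustrative:

    import mxnet as mx

    loss_fn = mx.gluon.loss.CTCLoss(layout='TNC', label_layout='TN')

    data = mx.nd.ones((20, 2, 4))     # (sequence_length, batch_size, alphabet_size)
    # (label_sequence_length, batch_size), padded with the default padding_mask (-1)
    label = mx.nd.array([[2, 3], [1, 2], [-1, 2], [-1, -1]])
    l = loss_fn(data, label)          # loss still has shape (batch_size,) == (2,)
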