contrib ctc interface changes, cudnn7 CTC, and gluon CTC (apache#7442)
* contrib ctc interface changes for compatibility

* cudnn ctc

* update per comments
szha authored and crazy-cat committed Oct 26, 2017
1 parent 3b9aa70 commit 75a07d1
Showing 6 changed files with 430 additions and 52 deletions.
90 changes: 90 additions & 0 deletions python/mxnet/gluon/loss.py
@@ -21,6 +21,8 @@
from __future__ import absolute_import

from .. import ndarray
from ..contrib import symbol as symbol_contrib
from ..contrib import ndarray as ndarray_contrib
from ..base import numeric_types
from .block import HybridBlock

@@ -295,3 +297,91 @@ def hybrid_forward(self, F, output, label, sample_weight=None):
        loss = label * (F.log(label+1e-8) - output)
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)


class CTCLoss(Loss):
    r"""Connectionist Temporal Classification Loss.

    See `"Connectionist Temporal Classification: Labelling Unsegmented
    Sequence Data with Recurrent Neural Networks"
    <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ paper for more information.

    Parameters
    ----------
    layout : str, default 'NTC'
        Layout of the output sequence activation vector.
    label_layout : str, default 'NT'
        Layout of the labels.
    padding_mask : int or None, default -1
        This is the label value to be considered padding, which is used to derive
        the actual lengths of labels. Only required when `label_lengths` is None.
    weight : float or None
        Global scalar weight for loss.
    sample_weight : Symbol or None
        Per-sample weighting. Must be broadcastable to
        the same shape as loss. For example, if loss has
        shape (64, 10) and you want to weight each sample
        in the batch, `sample_weight` should have shape (64, 1).
        This should be used as the fifth argument when calling this loss.

    Input shapes:
        `data` is an activation tensor without softmax applied.
        Its shape depends on `layout`. For `layout='TNC'`, this
        input has shape `(sequence_length, batch_size, alphabet_size)`.

        `label` is the label index matrix.
        Its shape depends on `label_layout`. For `label_layout='TN'`, this
        input has shape `(label_sequence_length, batch_size)`.
        When `label_lengths` is not specified, the first occurrence of `padding_mask`
        in each sample marks the end of the label sequence of that sample.
        For example, suppose there are two samples, with *label_sequence_length* = 4.
        The two label sequences are [2, 1] and [3, 2, 2], and both actual lengths
        are smaller than 4. Thus, given *padding_mask* = 0, the resulting `label`
        tensor should be padded to::

            [[2, 1, 0, 0], [3, 2, 2, 0]]

        `data_lengths` is optional and defaults to None.
        When specified, it represents the actual lengths of data.
        The shape should be (batch_size,).
        If None, the data lengths are treated as being equal to the max sequence length.
        This should be used as the third argument when calling this loss.

        `label_lengths` is optional and defaults to None.
        When specified, it represents the actual lengths of labels.
        The shape should be (batch_size,).
        If None, the label lengths are derived from the first occurrence of
        the value specified by `padding_mask`.
        This should be used as the fourth argument when calling this loss.

    Output shape:
        The CTC loss output has the shape (batch_size,).
    """
    def __init__(self, layout='NTC', label_layout='NT', padding_mask=-1,
                 weight=None, **kwargs):
        assert layout in ['NTC', 'TNC'],\
            "Only 'NTC' and 'TNC' layouts for output are supported. Got: %s" % layout
        assert label_layout in ['NT', 'TN'],\
            "Only 'NT' and 'TN' layouts for label are supported. Got: %s" % label_layout
        self._layout = layout
        self._label_layout = label_layout
        self._padding_mask = padding_mask
        batch_axis = label_layout.find('N')
        super(CTCLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, data, label,
                       data_lengths=None, label_lengths=None, sample_weight=None):
        # Convert data to 'TNC' and label to 'NT' so the contrib CTCLoss
        # operator always receives a consistent layout.
        if self._layout == 'NTC':
            data = F.swapaxes(data, 0, 1)
        if self._batch_axis == 1:
            label = F.swapaxes(label, 0, 1)
        # Dispatch to the ndarray or symbol contrib namespace, depending on
        # whether this block runs imperatively or has been hybridized.
        if F is ndarray:
            F_contrib = ndarray_contrib
        else:
            F_contrib = symbol_contrib
        loss = F_contrib.CTCLoss(data, label,
                                 use_data_lengths=data_lengths is not None,
                                 use_label_lengths=label_lengths is not None,
                                 data_lengths=data_lengths, label_lengths=label_lengths,
                                 padding_mask=self._padding_mask)
        return _apply_weighting(F, loss, self._weight, sample_weight)
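
For reference, a minimal usage sketch of the CTCLoss block added above — this assumes an MXNet build that includes this commit; the padded-label values come from the docstring example, and names like batch_size and alphabet_size are illustrative only:

    import mxnet as mx
    from mxnet.gluon.loss import CTCLoss

    # With the default layout='NTC', data has shape
    # (batch_size, sequence_length, alphabet_size).
    batch_size, seq_len, alphabet_size = 2, 4, 5
    data = mx.nd.random_uniform(shape=(batch_size, seq_len, alphabet_size))

    # Labels padded with padding_mask=0, as in the docstring example:
    # the actual label sequences are [2, 1] and [3, 2, 2].
    label = mx.nd.array([[2, 1, 0, 0], [3, 2, 2, 0]])

    loss = CTCLoss(layout='NTC', label_layout='NT', padding_mask=0)
    print(loss(data, label))  # NDArray of shape (batch_size,)

    # Equivalently, pass explicit lengths as the third and fourth arguments
    # instead of relying on padding_mask:
    print(loss(data, label,
               mx.nd.array([seq_len, seq_len]),  # data_lengths
               mx.nd.array([2, 3])))             # label_lengths

Either layout yields the same per-sample loss, since hybrid_forward swaps the data to 'TNC' internally before calling the contrib operator.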