Skip to content

Commit

Permalink
add blank_label parameter for CTCLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
HawkAaron committed Jul 20, 2018
1 parent 050df05 commit aab11f7
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/mxnet/gluon/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ class CTCLoss(Loss):
length respectively.
weight : float or None
Global scalar weight for loss.
blank_label : {'first', 'last'}, default 'last'
Set the label that is reserved for blank label.
Inputs:
Expand Down Expand Up @@ -452,13 +454,16 @@ class CTCLoss(Loss):
Sequence Data with Recurrent Neural Networks
<http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_
"""
def __init__(self, layout='NTC', label_layout='NT', weight=None, **kwargs):
def __init__(self, layout='NTC', label_layout='NT', weight=None, blank_label='last', **kwargs):
assert layout in ['NTC', 'TNC'],\
"Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s"%layout
assert label_layout in ['NT', 'TN'],\
"Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
assert blank_label in ['first', 'last'],\
"Only 'first' and 'last' are supported for blank_label. Got: %s"%blank_label
self._layout = layout
self._label_layout = label_layout
self._blank_label = blank_label
batch_axis = label_layout.find('N')
super(CTCLoss, self).__init__(weight, batch_axis, **kwargs)

Expand All @@ -471,7 +476,7 @@ def hybrid_forward(self, F, pred, label,
loss = F.contrib.CTCLoss(pred, label, pred_lengths, label_lengths,
use_data_lengths=pred_lengths is not None,
use_label_lengths=label_lengths is not None,
blank_label='last')
blank_label=self._blank_label)
return _apply_weighting(F, loss, self._weight, sample_weight)


Expand Down

0 comments on commit aab11f7

Please sign in to comment.