diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 2be43981a64c..a8194a639cb3 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -409,6 +409,8 @@ class CTCLoss(Loss):
         length respectively.
     weight : float or None
         Global scalar weight for loss.
+    blank_label : {'first', 'last'}, default 'last'
+        Set the label that is reserved for blank label.
 
 
     Inputs:
@@ -452,13 +454,16 @@ class CTCLoss(Loss):
         Sequence Data with Recurrent Neural Networks
         <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_
     """
-    def __init__(self, layout='NTC', label_layout='NT', weight=None, **kwargs):
+    def __init__(self, layout='NTC', label_layout='NT', weight=None, blank_label='last', **kwargs):
         assert layout in ['NTC', 'TNC'],\
             "Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s"%layout
         assert label_layout in ['NT', 'TN'],\
             "Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
+        assert blank_label in ['first', 'last'],\
+            "Only 'first' and 'last' are supported for blank_label. Got: %s"%blank_label
         self._layout = layout
         self._label_layout = label_layout
+        self._blank_label = blank_label
         batch_axis = label_layout.find('N')
         super(CTCLoss, self).__init__(weight, batch_axis, **kwargs)
 
@@ -471,7 +476,7 @@ def hybrid_forward(self, F, pred, label,
         loss = F.contrib.CTCLoss(pred, label, pred_lengths, label_lengths,
                                  use_data_lengths=pred_lengths is not None,
                                  use_label_lengths=label_lengths is not None,
-                                 blank_label='last')
+                                 blank_label=self._blank_label)
         return _apply_weighting(F, loss, self._weight, sample_weight)
 
 