Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for merge] Add Smooth-Regularized CTC #1769

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 130 additions & 18 deletions egs/librispeech/ASR/zipformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
from typing import List, Optional, Tuple

import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from scaling import ScaledLinear

from icefall.utils import add_sos, make_pad_mask
from icefall.utils import add_sos, make_pad_mask, time_warp
from lhotse.dataset import SpecAugment


class AsrModel(nn.Module):
Expand Down Expand Up @@ -110,9 +112,8 @@ def __init__(
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Dropout(p=0.1), # TODO: test removing this
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)

self.use_attention_decoder = use_attention_decoder
def forward_ctc(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    targets: torch.Tensor,
    target_lengths: torch.Tensor,
    use_consistency_reg: bool = False,
    use_smooth_reg: bool = False,
    smooth_kernel: Optional[List[float]] = None,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute CTC loss with optional consistency / smoothness regularization.

    Args:
      encoder_out:
        Encoder output, of shape (N or 2 * N, T, C). The batch is duplicated
        (2 * N) by the caller when consistency regularization is used.
      encoder_out_lens:
        Encoder output lengths, of shape (N or 2 * N,).
      targets:
        Target Tensor of shape (sum(target_lengths)). The targets are assumed
        to be un-padded and concatenated within 1 dimension.
      target_lengths:
        Lengths of the target sequences, of shape (N or 2 * N,).
      use_consistency_reg:
        Whether to use consistency regularization (CR-CTC).
      use_smooth_reg:
        Whether to use smooth regularization (SR-CTC).
      smooth_kernel:
        3-tap kernel used to smooth the CTC posteriors along the time axis;
        must sum to 1. Defaults to [0.25, 0.5, 0.25]. The default is created
        per call to avoid the shared-mutable-default pitfall.
      eps:
        Small constant added before taking logs, for numerical stability.

    Returns:
      A tuple (ctc_loss, cr_loss, sr_loss); cr_loss / sr_loss are empty
      tensors when the corresponding regularization is disabled.
    """
    if smooth_kernel is None:
        smooth_kernel = [0.25, 0.5, 0.25]

    # CTC head projection; assumed to produce un-normalized logits
    # (normalization is applied below) -- confirm the head has no LogSoftmax.
    ctc_output = self.ctc_output(encoder_out)  # (N or 2 * N, T, C)
    # True at padded frames; unsqueezed so it broadcasts over the class dim.
    length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)

    if not use_smooth_reg:
        ctc_log_probs = F.log_softmax(ctc_output, dim=-1)
    else:
        # Keep the probabilities around for sr_loss; eps keeps log() finite.
        ctc_probs = ctc_output.softmax(dim=-1)
        ctc_log_probs = (ctc_probs + eps).log()

    # Compute CTC loss
    ctc_loss = torch.nn.functional.ctc_loss(
        log_probs=ctc_log_probs.permute(1, 0, 2),  # (T, N or 2 * N, C)
        targets=targets.cpu(),
        input_lengths=encoder_out_lens.cpu(),
        target_lengths=target_lengths.cpu(),
        reduction="sum",
    )

    if use_consistency_reg:
        assert ctc_log_probs.shape[0] % 2 == 0
        # KL between each copy and the detached distribution of the other
        # copy: exchange [x1, x2] -> [x2, x1] along the batch dimension.
        exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0)
        exchanged_targets = torch.cat(
            [exchanged_targets[1], exchanged_targets[0]], dim=0
        )
        cr_loss = nn.functional.kl_div(
            input=ctc_log_probs,
            target=exchanged_targets,
            reduction="none",
            log_target=True,
        )  # (2 * N, T, C)
        cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
    else:
        cr_loss = torch.empty(0)

    if use_smooth_reg:
        # Tolerance-based check: exact float equality on sum() rejects valid
        # kernels such as [1/3, 1/3, 1/3] due to rounding.
        assert (
            len(smooth_kernel) == 3 and abs(sum(smooth_kernel) - 1.0) < 1e-6
        ), smooth_kernel
        kernel = torch.tensor(
            smooth_kernel, dtype=ctc_probs.dtype, device=ctc_probs.device
        )
        # Depthwise conv over time: one 1x3 kernel per class channel.
        kernel = kernel.unsqueeze(0).unsqueeze(1).expand(
            ctc_probs.shape[-1], 1, 3
        )  # (C, 1, 3)
        smoothed_ctc_probs = F.conv1d(
            ctc_probs.detach().permute(0, 2, 1),  # (N or 2 * N, C, T)
            weight=kernel,
            stride=1,
            padding=0,
            groups=ctc_probs.shape[-1],
        ).permute(0, 2, 1)  # (N or 2 * N, T - 2, C)
        sr_loss = nn.functional.kl_div(
            input=ctc_log_probs[:, 1:-1],
            target=(smoothed_ctc_probs + eps).log(),
            reduction="none",
            log_target=True,
        )  # (N or 2 * N, T - 2, C)
        sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum()
    else:
        sr_loss = torch.empty(0)

    return ctc_loss, cr_loss, sr_loss

def forward_transducer(
self,
Expand Down Expand Up @@ -296,7 +351,13 @@ def forward(
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
use_cr_ctc: bool = False,
use_sr_ctc: bool = False,
use_spec_aug: bool = False,
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
Expand All @@ -316,9 +377,28 @@ def forward(
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
use_cr_ctc:
Whether use consistency-regularized CTC.
use_sr_ctc:
Whether use smooth-regularized CTC.
use_spec_aug:
Whether apply spec-augment manually, used only if use_cr_ctc is True.
spec_augment:
The SpecAugment instance that returns time masks,
used only if use_cr_ctc is True.
supervision_segments:
An int tensor of shape ``(S, 3)``. ``S`` is the number of
supervision segments that exist in ``features``.
Used only if use_cr_ctc is True.
time_warp_factor:
Parameter for the time warping; larger values mean more warping.
Set to ``None``, or less than ``1``, to disable.
Used only if use_cr_ctc is True.

Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
Return the transducer losses, CTC loss, AED loss,
and consistency-regularization loss in form of
(simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss)

Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
Expand All @@ -334,6 +414,24 @@ def forward(

device = x.device

if use_cr_ctc:
assert self.use_ctc
if use_spec_aug:
assert spec_augment is not None and spec_augment.time_warp_factor < 1
# Apply time warping before input duplicating
assert supervision_segments is not None
x = time_warp(
x,
time_warp_factor=time_warp_factor,
supervision_segments=supervision_segments,
)
# Independently apply frequency masking and time masking to the two copies
x = spec_augment(x.repeat(2, 1, 1))
else:
x = x.repeat(2, 1, 1)
x_lens = x_lens.repeat(2)
y = k2.ragged.cat([y, y], axis=0)

# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

Expand All @@ -351,21 +449,33 @@ def forward(
am_scale=am_scale,
lm_scale=lm_scale,
)
if use_cr_ctc:
simple_loss = simple_loss * 0.5
pruned_loss = pruned_loss * 0.5
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)

if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
ctc_loss, cr_loss, sr_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
use_consistency_reg=use_cr_ctc,
use_smooth_reg=use_sr_ctc,
)
if use_cr_ctc:
# We duplicate the batch when use_cr_ctc is True
ctc_loss = ctc_loss * 0.5
cr_loss = cr_loss * 0.5
sr_loss = sr_loss * 0.5
else:
ctc_loss = torch.empty(0)
cr_loss = torch.empty(0)
sr_loss = torch.empty(0)

if self.use_attention_decoder:
attention_decoder_loss = self.attention_decoder.calc_att_loss(
Expand All @@ -374,7 +484,9 @@ def forward(
ys=y.to(device),
ys_lens=y_lens.to(device),
)
if use_cr_ctc:
attention_decoder_loss = attention_decoder_loss * 0.5
else:
attention_decoder_loss = torch.empty(0)

return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss
Loading
Loading