diff --git a/nemo/collections/tts/models/aligner.py b/nemo/collections/tts/models/aligner.py index 49301afc1591..9ca04ee74f18 100644 --- a/nemo/collections/tts/models/aligner.py +++ b/nemo/collections/tts/models/aligner.py @@ -126,7 +126,7 @@ def forward(self, *, spec, spec_len, text, text_len, attn_prior=None): attn_soft, attn_logprob = self.alignment_encoder( queries=spec, keys=self.embed(text).transpose(1, 2), - mask=get_mask_from_lengths(text_len).unsqueeze(-1), + mask=get_mask_from_lengths(text_len).unsqueeze(-1) == 0, attn_prior=attn_prior, )