Commit

Fix adding positional embeddings in-place in transformer module (NVIDIA#7440)

Signed-off-by: Tamerlan Tabolov <[email protected]>
Co-authored-by: Cheng-Ping Hsieh <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
2 people authored and ssh-meister committed Oct 2, 2023
1 parent f925e8d commit 3770f84
Showing 1 changed file with 1 addition and 1 deletion.

nemo/collections/tts/modules/transformer.py
@@ -246,7 +246,7 @@ def forward(self, input, seq_lens, conditioning=None):
     def _forward(self, inp, mask, conditioning):
         pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
         pos_emb = self.pos_emb(pos_seq) * mask
-        inp += pos_emb
+        inp = inp + pos_emb
         inp = self.cond_input(inp, conditioning)
         out = self.drop(inp)
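
For context on why this one-line change matters: PyTorch autograd forbids in-place mutation of tensors it still needs for the backward pass, and inp += pos_emb also mutates the caller's tensor as a side effect, whereas inp = inp + pos_emb allocates a fresh tensor. Below is a minimal sketch of the failure mode; the shapes and the leaf-tensor setup are illustrative, not taken from NeMo.

import torch

# Stand-ins for the transformer input and positional embedding; the
# shapes are arbitrary. When inp is a leaf tensor that requires grad,
# autograd rejects the in-place add outright.
inp = torch.randn(2, 4, 8, requires_grad=True)
pos_emb = torch.randn(2, 4, 8)

try:
    inp += pos_emb              # old code: in-place add
except RuntimeError as err:
    print(err)                  # "a leaf Variable that requires grad is
                                # being used in an in-place operation."

out = inp + pos_emb             # fixed code: out-of-place, autograd-safe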
