Skip to content

Commit

Permalink
Fix bug in ConditionalInput: cat along the feature dim, not the batch…
Browse files Browse the repository at this point in the history
… dim (NVIDIA#7785)

Signed-off-by: anferico <[email protected]>
  • Loading branch information
anferico committed Oct 24, 2023
1 parent 214dcca commit bc61629
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/tts/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def __init__(self, hidden_dim, condition_dim, condition_types=[]):
def forward(self, inputs, conditioning=None):
"""
Args:
inputs (torch.tensor): B x T x C tensor.
inputs (torch.tensor): B x T x H tensor.
conditioning (torch.tensor): B x 1 x C conditioning embedding.
"""
if len(self.condition_types) > 0:
Expand All @@ -510,7 +510,7 @@ def forward(self, inputs, conditioning=None):

if "concat" in self.condition_types:
conditioning = conditioning.repeat(1, inputs.shape[1], 1)
inputs = torch.cat([inputs, conditioning])
inputs = torch.cat([inputs, conditioning], dim=-1)
inputs = self.concat_proj(inputs)

return inputs
Expand Down

0 comments on commit bc61629

Please sign in to comment.