diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index 92218e807aac..f1997c385020 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -493,7 +493,7 @@ def __init__(self, hidden_dim, condition_dim, condition_types=[]): def forward(self, inputs, conditioning=None): """ Args: - inputs (torch.tensor): B x T x C tensor. + inputs (torch.tensor): B x T x H tensor. conditioning (torch.tensor): B x 1 x C conditioning embedding. """ if len(self.condition_types) > 0: @@ -510,7 +510,7 @@ def forward(self, inputs, conditioning=None): if "concat" in self.condition_types: conditioning = conditioning.repeat(1, inputs.shape[1], 1) - inputs = torch.cat([inputs, conditioning]) + inputs = torch.cat([inputs, conditioning], dim=-1) inputs = self.concat_proj(inputs) return inputs