Skip to content

Commit

Permalink
Snake act (NVIDIA#7736)
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
nithinraok authored and pzelasko committed Jan 3, 2024
1 parent fb36f2e commit c34e1ad
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 20 deletions.
4 changes: 2 additions & 2 deletions examples/tts/conf/audio_codec/audio_codec_24000.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ model:
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
log_epochs: [10, 50]
log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false
log_wandb: true

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
Expand Down
4 changes: 2 additions & 2 deletions examples/tts/conf/audio_codec/encodec_24000.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ model:
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
log_epochs: [10, 50]
log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
log_wandb: false
log_wandb: true

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
Expand Down
28 changes: 27 additions & 1 deletion nemo/collections/asr/parts/utils/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

__all__ = ['Swish']
__all__ = ['Swish', 'Snake']


@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """
    Apply the snake activation: x + (alpha + eps)^-1 * sin(alpha * x)^2

    Args:
        x: input tensor with at least two dimensions. Everything after the
            first two dimensions is flattened into a single trailing axis
            before the activation is applied.
        alpha: frequency tensor; it has to broadcast against the flattened
            3-D view of ``x``.
        eps: small constant added to ``alpha`` before taking the reciprocal.

    Returns:
        Tensor with the same shape as the input ``x``.
    """
    original_shape = x.shape
    # Collapse all trailing dimensions so the activation works on a 3-D view.
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    inv_alpha = (alpha + eps).reciprocal()
    periodic = torch.sin(alpha * flat).pow(2)
    activated = flat + inv_alpha * periodic
    # Restore the caller's original layout.
    return activated.reshape(original_shape)


class Snake(nn.Module):
    """
    Snake activation module, as introduced in 'https://arxiv.org/abs/2006.08195'.

    Holds a learnable ``alpha`` parameter of shape (1, channels, 1),
    initialized to ones, and passes it to the ``snake`` function.

    Args:
        channels: size of the middle dimension of ``alpha``.
    """

    def __init__(self, channels: int):
        super().__init__()
        initial_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``snake`` to ``x`` using the module's ``alpha`` parameter."""
        activated = snake(x, self.alpha)
        return activated


class Swish(nn.SiLU):
Expand Down
20 changes: 20 additions & 0 deletions nemo/collections/tts/modules/audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch.nn as nn

from nemo.collections.asr.parts.utils.activations import Snake
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
Expand All @@ -42,6 +43,25 @@ def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]:
return padding, output_padding


class CodecActivation(nn.Module):
    """
    Wrapper module that applies either a snake or an ELU activation.

    Args:
        activation: name of the activation, matched case-insensitively.
            Accepted values are "snake" and "elu".
        channels: channel count passed to ``Snake``; not used for "elu".

    Raises:
        ValueError: if ``activation`` is neither "snake" nor "elu".
    """

    def __init__(self, activation: str = "elu", channels: int = 1):
        super().__init__()
        activation = activation.lower()
        # Reject unsupported names up front so only valid choices remain below.
        if activation not in ("snake", "elu"):
            raise ValueError(f"Unknown activation {activation}")
        self.activation = Snake(channels) if activation == "snake" else nn.ELU()

    def forward(self, x):
        """Apply the selected activation to ``x``."""
        activated = self.activation(x)
        return activated


class Conv1dNorm(NeuralModule):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Optional[int] = None
Expand Down
36 changes: 22 additions & 14 deletions nemo/collections/tts/modules/encodec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

from nemo.collections.tts.losses.audio_codec_loss import MaskedMSELoss
from nemo.collections.tts.modules.audio_codec_modules import (
CodecActivation,
Conv1dNorm,
Conv2dNorm,
ConvTranspose1dNorm,
Expand All @@ -61,12 +62,13 @@


class SEANetResnetBlock(NeuralModule):
def __init__(self, channels: int):
def __init__(self, channels: int, activation: str = "elu"):
super().__init__()
self.activation = nn.ELU()
self.pre_activation = CodecActivation(activation=activation, channels=channels)
hidden_channels = channels // 2
self.pre_conv = Conv1dNorm(in_channels=channels, out_channels=channels, kernel_size=1)
self.res_conv1 = Conv1dNorm(in_channels=channels, out_channels=hidden_channels, kernel_size=3)
self.post_activation = CodecActivation(activation=activation, channels=hidden_channels)
self.res_conv2 = Conv1dNorm(in_channels=hidden_channels, out_channels=channels, kernel_size=1)

@property
Expand All @@ -89,9 +91,9 @@ def remove_weight_norm(self):

@typecheck()
def forward(self, inputs, input_len):
res = self.activation(inputs)
res = self.pre_activation(inputs)
res = self.res_conv1(inputs=res, input_len=input_len)
res = self.activation(res)
res = self.post_activation(res)
res = self.res_conv2(inputs=res, input_len=input_len)

out = self.pre_conv(inputs=inputs, input_len=input_len) + res
Expand Down Expand Up @@ -148,6 +150,7 @@ def __init__(
in_kernel_size: int = 7,
out_kernel_size: int = 7,
encoded_dim: int = 128,
activation: str = "elu",
rnn_layers: int = 2,
rnn_type: str = "lstm",
rnn_skip: bool = True,
Expand All @@ -158,15 +161,16 @@ def __init__(
super().__init__()

self.down_sample_rates = down_sample_rates
self.activation = nn.ELU()
self.pre_conv = Conv1dNorm(in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size)

in_channels = base_channels
self.res_blocks = nn.ModuleList([])
self.down_sample_conv_layers = nn.ModuleList([])
self.activations = nn.ModuleList([])
for i, down_sample_rate in enumerate(self.down_sample_rates):
res_block = SEANetResnetBlock(channels=in_channels)
self.res_blocks.append(res_block)
self.activations.append(CodecActivation(activation=activation, channels=in_channels))

out_channels = 2 * in_channels
kernel_size = 2 * down_sample_rate
Expand All @@ -180,6 +184,7 @@ def __init__(
in_channels = out_channels
self.down_sample_conv_layers.append(down_sample_conv)

self.post_activation = CodecActivation(activation=activation, channels=in_channels)
self.rnn = SEANetRNN(dim=in_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip)
self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size)

Expand Down Expand Up @@ -211,19 +216,19 @@ def forward(self, audio, audio_len):
audio = rearrange(audio, "B T -> B 1 T")
# [B, C, T_audio]
out = self.pre_conv(inputs=audio, input_len=encoded_len)
for res_block, down_sample_conv, down_sample_rate in zip(
self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates
for res_block, down_sample_conv, down_sample_rate, activation in zip(
self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates, self.activations
):
# [B, C, T]
out = res_block(inputs=out, input_len=encoded_len)
out = self.activation(out)
out = activation(out)

encoded_len = encoded_len // down_sample_rate
# [B, 2 * C, T / down_sample_rate]
out = down_sample_conv(inputs=out, input_len=encoded_len)

out = self.rnn(inputs=out, input_len=encoded_len)
out = self.activation(out)
out = self.post_activation(out)
# [B, encoded_dim, T_encoded]
encoded = self.post_conv(inputs=out, input_len=encoded_len)
return encoded, encoded_len
Expand All @@ -237,6 +242,7 @@ def __init__(
in_kernel_size: int = 7,
out_kernel_size: int = 3,
encoded_dim: int = 128,
activation: str = "elu",
rnn_layers: int = 2,
rnn_type: str = "lstm",
rnn_skip: bool = True,
Expand All @@ -247,14 +253,15 @@ def __init__(
super().__init__()

self.up_sample_rates = up_sample_rates
self.activation = nn.ELU()
self.pre_conv = Conv1dNorm(in_channels=encoded_dim, out_channels=base_channels, kernel_size=in_kernel_size)
self.rnn = SEANetRNN(dim=base_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip)

in_channels = base_channels
self.res_blocks = nn.ModuleList([])
self.up_sample_conv_layers = nn.ModuleList([])
self.activations = nn.ModuleList([])
for i, up_sample_rate in enumerate(self.up_sample_rates):
self.activations.append(CodecActivation(activation=activation, channels=in_channels))
out_channels = in_channels // 2
kernel_size = 2 * up_sample_rate
up_sample_conv = ConvTranspose1dNorm(
Expand All @@ -266,6 +273,7 @@ def __init__(
res_block = SEANetResnetBlock(channels=in_channels)
self.res_blocks.append(res_block)

self.post_activation = CodecActivation(activation=activation, channels=in_channels)
self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size)
self.out_activation = nn.Tanh()

Expand Down Expand Up @@ -296,16 +304,16 @@ def forward(self, inputs, input_len):
# [B, C, T_encoded]
out = self.pre_conv(inputs=inputs, input_len=audio_len)
out = self.rnn(inputs=out, input_len=audio_len)
for res_block, up_sample_conv, up_sample_rate in zip(
self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates
for res_block, up_sample_conv, up_sample_rate, activation in zip(
self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates, self.activations
):
audio_len = audio_len * up_sample_rate
out = self.activation(out)
out = activation(out)
# [B, C / 2, T * up_sample_rate]
out = up_sample_conv(inputs=out, input_len=audio_len)
out = res_block(inputs=out, input_len=audio_len)

out = self.activation(out)
out = self.post_activation(out)
# [B, 1, T_audio]
out = self.post_conv(inputs=out, input_len=audio_len)
audio = self.out_activation(out)
Expand Down
25 changes: 24 additions & 1 deletion tests/collections/tts/modules/test_audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import pytest
import torch

from nemo.collections.tts.modules.audio_codec_modules import Conv1dNorm, ConvTranspose1dNorm, get_down_sample_padding
from nemo.collections.tts.modules.audio_codec_modules import (
CodecActivation,
Conv1dNorm,
ConvTranspose1dNorm,
get_down_sample_padding,
)
from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer


Expand Down Expand Up @@ -182,3 +187,21 @@ def test_group_rvq_eval(self, num_groups: int, num_codebooks: int):
torch.testing.assert_close(
indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}'
)


class TestCodecActivation:
    """Execution tests for the CodecActivation wrapper."""

    def setup_class(self):
        # Dimensions of the random input batch used by the tests in this class.
        self.batch_size, self.in_channels, self.max_len = 2, 4, 4

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_snake(self):
        """
        Check that the snake activation runs and preserves the input shape.
        """
        expected_shape = (self.batch_size, self.in_channels, self.max_len)
        inputs = torch.rand(list(expected_shape))
        activation = CodecActivation('snake', channels=self.in_channels)
        out = activation(x=inputs)
        assert out.shape == expected_shape

0 comments on commit c34e1ad

Please sign in to comment.