Snake act #7736

Merged
merged 39 commits on Oct 20, 2023
Changes from all commits
39 commits
9a8e089
replace elu to snake
Sep 19, 2023
4ceb9c0
adding torch tensor
Sep 19, 2023
031a94f
add wandb recons logging
Sep 20, 2023
58883dd
support typecheck
Oct 4, 2023
31ad191
replace elu to snake
Sep 19, 2023
4c08299
adding torch tensor
Sep 19, 2023
9e6b8f5
choose between snake or elu activation
Oct 13, 2023
713c5b8
update based on comments
Oct 17, 2023
3558529
update yaml
Oct 17, 2023
d7b6a42
Merge branch 'main' into snake_act
nithinraok Oct 17, 2023
5d809df
Merge branch 'main' into snake_act
nithinraok Oct 18, 2023
1be02ff
replace elu to snake
Sep 19, 2023
0a567bb
adding torch tensor
Sep 19, 2023
41aa947
add wandb recons logging
Sep 20, 2023
2c0f070
support typecheck
Oct 4, 2023
a7519f1
replace elu to snake
Sep 19, 2023
a7589c8
adding torch tensor
Sep 19, 2023
1b75989
choose between snake or elu activation
Oct 13, 2023
09a1c66
update based on comments
Oct 17, 2023
147b8f3
update yaml
Oct 17, 2023
dded2a7
update
Oct 19, 2023
2baeea8
replace elu to snake
Sep 19, 2023
d13063f
adding torch tensor
Sep 19, 2023
22cf0cd
add wandb recons logging
Sep 20, 2023
6c15ab0
support typecheck
Oct 4, 2023
94653ca
replace elu to snake
Sep 19, 2023
bb86120
adding torch tensor
Sep 19, 2023
b5b8152
choose between snake or elu activation
Oct 13, 2023
b0a61af
update based on comments
Oct 17, 2023
ced73e6
update yaml
Oct 17, 2023
aeb2043
replace elu to snake
Sep 19, 2023
5b7f7e0
adding torch tensor
Sep 19, 2023
8c30f8e
replace elu to snake
Sep 19, 2023
2e65b1c
adding torch tensor
Sep 19, 2023
562398b
choose between snake or elu activation
Oct 13, 2023
a3fd702
update based on comments
Oct 17, 2023
663db12
update yaml
Oct 17, 2023
8942c5a
merge
Oct 19, 2023
2cba872
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2023
4 changes: 2 additions & 2 deletions examples/tts/conf/audio_codec/audio_codec_24000.yaml
@@ -87,10 +87,10 @@ model:
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
-log_epochs: [10, 50]
+log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
-log_wandb: false
+log_wandb: true

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
4 changes: 2 additions & 2 deletions examples/tts/conf/audio_codec/encodec_24000.yaml
@@ -87,10 +87,10 @@ model:
# Remove this section to disable logging.
log_config:
log_dir: ${log_dir}
-log_epochs: [10, 50]
+log_epochs: [10, 50, 100, 150, 200]
epoch_frequency: 100
log_tensorboard: false
-log_wandb: false
+log_wandb: true

generators:
- _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
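Note: both codec configs (audio_codec_24000.yaml and encodec_24000.yaml) now log reconstruction artifacts on more epochs and push them to Weights & Biases. As a rough illustration of how log_epochs and epoch_frequency typically combine, here is a minimal Python sketch; the helper name and the exact gating rule are assumptions for illustration, not NeMo's actual callback logic.

# Hypothetical sketch of epoch-gated artifact logging (not NeMo's implementation).
def should_log_artifacts(epoch: int, log_epochs: list, epoch_frequency: int) -> bool:
    # Log on every explicitly listed epoch, plus on every `epoch_frequency`-th epoch.
    return epoch in log_epochs or (epoch_frequency > 0 and epoch % epoch_frequency == 0)

# With the updated config: epochs 10, 50, 100, 150, 200 always log,
# and every 100th epoch logs regardless of the explicit list.
assert should_log_artifacts(150, [10, 50, 100, 150, 200], 100)
assert not should_log_artifacts(73, [10, 50, 100, 150, 200], 100)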
28 changes: 27 additions & 1 deletion nemo/collections/asr/parts/utils/activations.py
@@ -12,9 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

-__all__ = ['Swish']
+__all__ = ['Swish', 'Snake']


@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
"""
equation for snake activation function: x + (alpha + eps)^-1 * sin(alpha * x)^2
"""
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + eps).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x


class Snake(nn.Module):
"""
Snake activation function introduced in 'https://arxiv.org/abs/2006.08195'
"""

def __init__(self, channels: int):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return snake(x, self.alpha)


class Swish(nn.SiLU):
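A quick sanity check of the new activation, using the import path from this diff: Snake keeps the input shape and matches a direct evaluation of x + (alpha + eps)^-1 * sin(alpha * x)^2 with its per-channel learnable alpha (initialized to ones).

import torch

from nemo.collections.asr.parts.utils.activations import Snake

x = torch.randn(2, 4, 16)  # [B, C, T]
act = Snake(channels=4)    # alpha has shape [1, 4, 1]
y = act(x)

# Evaluate the formula directly and compare against the jit-scripted kernel.
eps = 1e-9
expected = x + (act.alpha + eps).reciprocal() * torch.sin(act.alpha * x).pow(2)
torch.testing.assert_close(y, expected)
print(y.shape)  # torch.Size([2, 4, 16])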
20 changes: 20 additions & 0 deletions nemo/collections/tts/modules/audio_codec_modules.py
@@ -16,6 +16,7 @@

import torch.nn as nn

from nemo.collections.asr.parts.utils.activations import Snake
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
@@ -42,6 +43,25 @@ def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]:
return padding, output_padding


class CodecActivation(nn.Module):
"""
Choose between snake or Elu activation based on the input parameter.
"""

def __init__(self, activation: str = "elu", channels: int = 1):
super().__init__()
activation = activation.lower()
if activation == "snake":
self.activation = Snake(channels)
elif activation == "elu":
self.activation = nn.ELU()
else:
raise ValueError(f"Unknown activation {activation}")

def forward(self, x):
return self.activation(x)


class Conv1dNorm(NeuralModule):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Optional[int] = None
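The new wrapper can also be exercised on its own; a minimal usage sketch with names and shapes taken from this PR:

import torch

from nemo.collections.tts.modules.audio_codec_modules import CodecActivation

x = torch.randn(2, 8, 32)  # [batch, channels, time]

# "snake" builds a Snake module with a learnable per-channel alpha,
# "elu" falls back to nn.ELU(), and any other name raises ValueError.
snake_act = CodecActivation(activation="snake", channels=8)
elu_act = CodecActivation(activation="elu", channels=8)

assert snake_act(x).shape == x.shape
assert elu_act(x).shape == x.shape

Note that channels only matters for the snake path, since ELU has no per-channel parameters.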
36 changes: 22 additions & 14 deletions nemo/collections/tts/modules/encodec_modules.py
@@ -45,6 +45,7 @@

from nemo.collections.tts.losses.audio_codec_loss import MaskedMSELoss
from nemo.collections.tts.modules.audio_codec_modules import (
CodecActivation,
Conv1dNorm,
Conv2dNorm,
ConvTranspose1dNorm,
@@ -61,12 +62,13 @@


class SEANetResnetBlock(NeuralModule):
-def __init__(self, channels: int):
+def __init__(self, channels: int, activation: str = "elu"):
super().__init__()
-self.activation = nn.ELU()
+self.pre_activation = CodecActivation(activation=activation, channels=channels)
hidden_channels = channels // 2
self.pre_conv = Conv1dNorm(in_channels=channels, out_channels=channels, kernel_size=1)
self.res_conv1 = Conv1dNorm(in_channels=channels, out_channels=hidden_channels, kernel_size=3)
self.post_activation = CodecActivation(activation=activation, channels=hidden_channels)
self.res_conv2 = Conv1dNorm(in_channels=hidden_channels, out_channels=channels, kernel_size=1)

@property
@@ -89,9 +91,9 @@ def remove_weight_norm(self):

@typecheck()
def forward(self, inputs, input_len):
-res = self.activation(inputs)
+res = self.pre_activation(inputs)
res = self.res_conv1(inputs=res, input_len=input_len)
-res = self.activation(res)
+res = self.post_activation(res)
res = self.res_conv2(inputs=res, input_len=input_len)

out = self.pre_conv(inputs=inputs, input_len=input_len) + res
@@ -148,6 +150,7 @@ def __init__(
in_kernel_size: int = 7,
out_kernel_size: int = 7,
encoded_dim: int = 128,
activation: str = "elu",
rnn_layers: int = 2,
rnn_type: str = "lstm",
rnn_skip: bool = True,
@@ -158,15 +161,16 @@
super().__init__()

self.down_sample_rates = down_sample_rates
-self.activation = nn.ELU()
self.pre_conv = Conv1dNorm(in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size)

in_channels = base_channels
self.res_blocks = nn.ModuleList([])
self.down_sample_conv_layers = nn.ModuleList([])
self.activations = nn.ModuleList([])
for i, down_sample_rate in enumerate(self.down_sample_rates):
res_block = SEANetResnetBlock(channels=in_channels)
self.res_blocks.append(res_block)
self.activations.append(CodecActivation(activation=activation, channels=in_channels))

out_channels = 2 * in_channels
kernel_size = 2 * down_sample_rate
@@ -180,6 +184,7 @@
in_channels = out_channels
self.down_sample_conv_layers.append(down_sample_conv)

self.post_activation = CodecActivation(activation=activation, channels=in_channels)
self.rnn = SEANetRNN(dim=in_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip)
self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=encoded_dim, kernel_size=out_kernel_size)

@@ -211,19 +216,19 @@ def forward(self, audio, audio_len):
audio = rearrange(audio, "B T -> B 1 T")
# [B, C, T_audio]
out = self.pre_conv(inputs=audio, input_len=encoded_len)
-for res_block, down_sample_conv, down_sample_rate in zip(
-self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates
+for res_block, down_sample_conv, down_sample_rate, activation in zip(
+self.res_blocks, self.down_sample_conv_layers, self.down_sample_rates, self.activations
):
# [B, C, T]
out = res_block(inputs=out, input_len=encoded_len)
-out = self.activation(out)
+out = activation(out)

encoded_len = encoded_len // down_sample_rate
# [B, 2 * C, T / down_sample_rate]
out = down_sample_conv(inputs=out, input_len=encoded_len)

out = self.rnn(inputs=out, input_len=encoded_len)
-out = self.activation(out)
+out = self.post_activation(out)
# [B, encoded_dim, T_encoded]
encoded = self.post_conv(inputs=out, input_len=encoded_len)
return encoded, encoded_len
@@ -237,6 +242,7 @@ def __init__(
in_kernel_size: int = 7,
out_kernel_size: int = 3,
encoded_dim: int = 128,
activation: str = "elu",
rnn_layers: int = 2,
rnn_type: str = "lstm",
rnn_skip: bool = True,
@@ -247,14 +253,15 @@
super().__init__()

self.up_sample_rates = up_sample_rates
-self.activation = nn.ELU()
self.pre_conv = Conv1dNorm(in_channels=encoded_dim, out_channels=base_channels, kernel_size=in_kernel_size)
self.rnn = SEANetRNN(dim=base_channels, num_layers=rnn_layers, rnn_type=rnn_type, use_skip=rnn_skip)

in_channels = base_channels
self.res_blocks = nn.ModuleList([])
self.up_sample_conv_layers = nn.ModuleList([])
self.activations = nn.ModuleList([])
for i, up_sample_rate in enumerate(self.up_sample_rates):
self.activations.append(CodecActivation(activation=activation, channels=in_channels))
out_channels = in_channels // 2
kernel_size = 2 * up_sample_rate
up_sample_conv = ConvTranspose1dNorm(
@@ -266,6 +273,7 @@
res_block = SEANetResnetBlock(channels=in_channels)
self.res_blocks.append(res_block)

self.post_activation = CodecActivation(activation=activation, channels=in_channels)
self.post_conv = Conv1dNorm(in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size)
self.out_activation = nn.Tanh()

@@ -296,16 +304,16 @@ def forward(self, inputs, input_len):
# [B, C, T_encoded]
out = self.pre_conv(inputs=inputs, input_len=audio_len)
out = self.rnn(inputs=out, input_len=audio_len)
-for res_block, up_sample_conv, up_sample_rate in zip(
-self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates
+for res_block, up_sample_conv, up_sample_rate, activation in zip(
+self.res_blocks, self.up_sample_conv_layers, self.up_sample_rates, self.activations
):
audio_len = audio_len * up_sample_rate
-out = self.activation(out)
+out = activation(out)
# [B, C / 2, T * up_sample_rate]
out = up_sample_conv(inputs=out, input_len=audio_len)
out = res_block(inputs=out, input_len=audio_len)

-out = self.activation(out)
+out = self.post_activation(out)
# [B, 1, T_audio]
out = self.post_conv(inputs=out, input_len=audio_len)
audio = self.out_activation(out)
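The structural point of the changes above: a single shared nn.ELU can no longer be reused across the network, because Snake carries a learnable per-channel alpha. Each call site therefore gets its own CodecActivation sized to its channel count: separate pre- and post-activations in the residual block, and one activation per down/up-sampling stage. A minimal sketch of the configurable residual block, assuming its forward returns the masked output tensor as in the surrounding code:

import torch

from nemo.collections.tts.modules.encodec_modules import SEANetResnetBlock

# Residual block whose pre-/post-activations are selected by name.
block = SEANetResnetBlock(channels=16, activation="snake")
# pre_activation acts on 16 channels, post_activation on hidden_channels = 8.

inputs = torch.randn(2, 16, 64)     # [B, C, T]
input_len = torch.tensor([64, 48])  # valid length per example
out = block(inputs=inputs, input_len=input_len)  # same [B, C, T] shape as inputs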
25 changes: 24 additions & 1 deletion tests/collections/tts/modules/test_audio_codec_modules.py
@@ -15,7 +15,12 @@
import pytest
import torch

-from nemo.collections.tts.modules.audio_codec_modules import Conv1dNorm, ConvTranspose1dNorm, get_down_sample_padding
+from nemo.collections.tts.modules.audio_codec_modules import (
+CodecActivation,
+Conv1dNorm,
+ConvTranspose1dNorm,
+get_down_sample_padding,
+)
from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer


@@ -182,3 +187,21 @@ def test_group_rvq_eval(self, num_groups: int, num_codebooks: int):
torch.testing.assert_close(
indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}'
)


class TestCodecActivation:
def setup_class(self):
self.batch_size = 2
self.in_channels = 4
self.max_len = 4

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_snake(self):
"""
Test for snake activation function execution.
"""
inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
snake = CodecActivation('snake', channels=self.in_channels)
out = snake(x=inputs)
assert out.shape == (self.batch_size, self.in_channels, self.max_len)
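
The added test exercises only the snake path. A complementary check of the ELU fallback and the error branch could look like the sketch below (not part of this PR; it assumes the imports and class attributes of TestCodecActivation above):

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_elu_fallback(self):
        """ELU fallback preserves the input shape."""
        inputs = torch.rand([self.batch_size, self.in_channels, self.max_len])
        elu = CodecActivation('elu', channels=self.in_channels)
        out = elu(x=inputs)
        assert out.shape == (self.batch_size, self.in_channels, self.max_len)

    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_unknown_activation(self):
        """An unrecognized activation name raises ValueError."""
        with pytest.raises(ValueError):
            CodecActivation('tanh', channels=self.in_channels)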