Skip to content

Commit

Permalink
Group-residual vector quantizer
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Oct 6, 2023
1 parent 983e9e1 commit 924b3b1
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 0 deletions.
141 changes: 141 additions & 0 deletions nemo/collections/tts/modules/encodec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, LossType, VoidType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging
from nemo.utils.decorators import experimental


Expand Down Expand Up @@ -807,3 +808,143 @@ def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
dequantized = dequantized + dequantized_i
dequantized = rearrange(dequantized, "B T D -> B D T")
return dequantized


class GroupResidualVectorQuantizer(NeuralModule):
"""Split the input vector into groups and apply RVQ on each group separately.
Args:
num_codebooks: total number of codebooks
num_groups: number of groups to split the input into, each group will be quantized separately using num_codebooks//num_groups codebooks
codebook_dim: embedding dimension, will be split into num_groups
**kwargs: parameters of ResidualVectorQuantizer
References:
Yang et al, HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765).
"""

def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwargs):
super().__init__()

self.num_codebooks = num_codebooks
self.num_groups = num_groups
self.codebook_dim = codebook_dim

# Initialize RVQ for each group
self.rvqs = torch.nn.ModuleList(
[
ResidualVectorQuantizer(
num_codebooks=self.num_codebooks_per_group, codebook_dim=self.codebook_dim_per_group, **kwargs
)
for _ in range(self.num_groups)
]
)

logging.debug('Initialized {self.__class__.__name__} with')
logging.debug('\tnum_codebooks: %d', self.num_codebooks)
logging.debug('\tnum_groups: %d', self.num_groups)
logging.debug('\tcodebook_dim: %d', self.codebook_dim)
logging.debug('\tnum_codebooks_per_group: %d', self.num_codebooks_per_group)
logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group)

@property
def num_codebooks_per_group(self):
"""Number of codebooks for each group.
"""
if self.num_codebooks % self.num_groups != 0:
raise ValueError(
f'num_codebooks ({self.num_codebooks}) must be divisible by num_groups ({self.num_groups})'
)

return self.num_codebooks // self.num_groups

@property
def codebook_dim_per_group(self):
"""Input vector dimension for each group.
"""
if self.codebook_dim % self.num_groups != 0:
raise ValueError(f'codebook_dim ({self.codebook_dim}) must be divisible by num_groups ({self.num_groups})')

return self.codebook_dim // self.num_groups

@property
def input_types(self):
return {
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
return {
"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"indices": NeuralType(('D', 'B', 'T'), Index()),
"commit_loss": NeuralType((), LossType()),
}

@typecheck()
def forward(self, inputs, input_len):
"""Quantize each group separately, then concatenate the results.
"""
inputs_grouped = inputs.chunk(self.num_groups, dim=1)

dequantized, indices = [], []
commit_loss = 0

for in_group, rvq_group in zip(inputs_grouped, self.rvqs):
dequantized_group, indices_group, commit_loss_group = rvq_group(inputs=in_group, input_len=input_len)
dequantized.append(dequantized_group)
indices.append(indices_group)
commit_loss += commit_loss_group

# concatenate along the feature dimension
dequantized = torch.cat(dequantized, dim=1)

# concatente along the codebook dimension
indices = torch.cat(indices, dim=0)

return dequantized, indices, commit_loss

@typecheck(
input_types={
"inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"indices": NeuralType(('D', 'B', 'T'), Index())},
)
def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor:
"""Input is split into groups, each group is encoded separately, then the results are concatenated.
"""
inputs_grouped = inputs.chunk(self.num_groups, dim=1)
indices = []

for in_group, rvq_group in zip(inputs_grouped, self.rvqs):
indices_group = rvq_group.encode(inputs=in_group, input_len=input_len)
indices.append(indices_group)

# concatenate along the codebook dimension
indices = torch.cat(indices, dim=0)

return indices

@typecheck(
input_types={
"indices": NeuralType(('D', 'B', 'T'), Index()),
"input_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),},
)
def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
"""Input indices are split into groups, each group is decoded separately, then the results are concatenated.
"""
indices_grouped = indices.chunk(self.num_groups, dim=0)
dequantized = []

for indices_group, rvq_group in zip(indices_grouped, self.rvqs):
dequantized_group = rvq_group.decode(indices=indices_group, input_len=input_len)
dequantized.append(dequantized_group)

# concatenate along the feature dimension
dequantized = torch.cat(dequantized, dim=1)

return dequantized
93 changes: 93 additions & 0 deletions tests/collections/tts/modules/test_audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch

from nemo.collections.tts.modules.audio_codec_modules import Conv1dNorm, ConvTranspose1dNorm, get_down_sample_padding
from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer


class TestAudioCodecModules:
Expand Down Expand Up @@ -89,3 +90,95 @@ def test_conv1d_transpose_upsample(self):
assert torch.all(out[0, :, out_len_1:] == 0.0)
assert torch.all(out[1, :, :out_len_2] != 0.0)
assert torch.all(out[1, :, out_len_2:] == 0.0)


class TestResidualVectorQuantizer:
def setup_class(self):
"""Setup common members
"""
self.batch_size = 2
self.max_len = 20
self.codebook_size = 256
self.codebook_dim = 64
self.num_examples = 10

@pytest.mark.unit
@pytest.mark.parametrize('num_codebooks', [1, 4])
def test_rvq_eval(self, num_codebooks: int):
"""Simple test to confirm that the RVQ module can be instantiated and run,
and that forward produces the same result as encode-decode.
"""
# instantiate and set in eval mode
rvq = ResidualVectorQuantizer(num_codebooks=num_codebooks, codebook_dim=self.codebook_dim)
rvq.eval()

for i in range(self.num_examples):
inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

# apply forward
dequantized_fw, indices_fw, commit_loss = rvq(inputs=inputs, input_len=input_len)

# make sure the commit loss is zero
assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0'

# encode-decode
indices_enc = rvq.encode(inputs=inputs, input_len=input_len)
dequantized_dec = rvq.decode(indices=indices_enc, input_len=input_len)

# make sure the results are the same
torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

@pytest.mark.unit
@pytest.mark.parametrize('num_groups', [1, 2, 4])
@pytest.mark.parametrize('num_codebooks', [1, 4])
def test_group_rvq_eval(self, num_groups: int, num_codebooks: int):
"""Simple test to confirm that the group RVQ module can be instantiated and run,
and that forward produces the same result as encode-decode.
"""
if num_groups > num_codebooks:
# Expected to fail if num_groups is lager than the total number of codebooks
with pytest.raises(ValueError):
_ = GroupResidualVectorQuantizer(
num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
)
else:
# Test inference with group RVQ
# instantiate and set in eval mode
grvq = GroupResidualVectorQuantizer(
num_codebooks=num_codebooks, num_groups=num_groups, codebook_dim=self.codebook_dim
)
grvq.eval()

for i in range(self.num_examples):
inputs = torch.randn([self.batch_size, self.codebook_dim, self.max_len])
input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32)

# apply forward
dequantized_fw, indices_fw, commit_loss = grvq(inputs=inputs, input_len=input_len)

# make sure the commit loss is zero
assert commit_loss == 0.0, f'example {i}: commit_loss is {commit_loss}, expected 0.0'

# encode-decode
indices_enc = grvq.encode(inputs=inputs, input_len=input_len)
dequantized_dec = grvq.decode(indices=indices_enc, input_len=input_len)

# make sure the results are the same
torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch')
torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch')

# apply individual RVQs and make sure the results are the same
inputs_grouped = inputs.chunk(num_groups, dim=1)
dequantized_fw_grouped = dequantized_fw.chunk(num_groups, dim=1)
indices_fw_grouped = indices_fw.chunk(num_groups, dim=0)

for g in range(num_groups):
dequantized, indices, _ = grvq.rvqs[g](inputs=inputs_grouped[g], input_len=input_len)
torch.testing.assert_close(
dequantized, dequantized_fw_grouped[g], msg=f'example {i}: dequantized mismatch for group {g}'
)
torch.testing.assert_close(
indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}'
)

0 comments on commit 924b3b1

Please sign in to comment.