From 5ba0ae9ca9d3eb9f0268ab511cee8637743eb820 Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Fri, 19 Jan 2024 18:08:53 -0800 Subject: [PATCH] Added VectorQuantizer base class (#8011) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić Signed-off-by: Pablo Garay --- .../tts/modules/audio_codec_modules.py | 78 +++++++++++-------- .../tts/modules/encodec_modules.py | 21 ++--- 2 files changed, 52 insertions(+), 47 deletions(-) diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index 5773462a9310..933d09a67fa1 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import Iterable, List, Optional, Tuple import torch @@ -340,7 +341,50 @@ def forward(self, audio_real, audio_gen): return scores_real, scores_gen, fmaps_real, fmaps_gen -class FiniteScalarQuantizer(NeuralModule): +class VectorQuantizerBase(NeuralModule, ABC): + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + return { + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "indices": NeuralType(('D', 'B', 'T'), Index()), + } + + @typecheck() + @abstractmethod + def forward(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + ) + @abstractmethod + def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: + pass + + @typecheck( + input_types={ + "indices": NeuralType(('D', 'B', 'T'), Index()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + ) + @abstractmethod + def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: + pass + + +class FiniteScalarQuantizer(VectorQuantizerBase): """This quantizer is based on the Finite Scalar Quantization (FSQ) method. It quantizes each element of the input vector independently into a number of levels. @@ -478,21 +522,7 @@ def codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor: indices = torch.sum(indices * self.dim_base_index, dim=1) return indices.to(torch.int32) - # API of the RVQ - @property - def input_types(self): - return { - "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), - } - - @property - def output_types(self): - return { - "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "indices": NeuralType(('D', 'B', 'T'), Index()), - } - + # Implementation of VectorQuantiserBase API @typecheck() def forward( self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None @@ -556,7 +586,7 @@ def decode(self, indices: torch.Tensor, input_len: Optional[torch.Tensor] = None return dequantized -class GroupFiniteScalarQuantizer(NeuralModule): +class GroupFiniteScalarQuantizer(VectorQuantizerBase): """Split the input vector into groups and apply FSQ on each group separately. This class is for convenience. Since FSQ is applied on each group separately, groups can be defined arbitrarily by splitting the input vector. However, this @@ -604,20 +634,6 @@ def codebook_size(self): """Returns the size of the implicit codebook.""" return self.codebook_size_per_group ** self.num_groups - @property - def input_types(self): - return { - "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "input_len": NeuralType(tuple('B'), LengthsType()), - } - - @property - def output_types(self): - return { - "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "indices": NeuralType(('D', 'B', 'T'), Index()), - } - @typecheck() def forward(self, inputs, input_len): """Quantize each group separately, then concatenate the results. diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index 1719dcc15179..e93c7c799550 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -49,6 +49,7 @@ Conv1dNorm, Conv2dNorm, ConvTranspose1dNorm, + VectorQuantizerBase, get_down_sample_padding, ) from nemo.collections.tts.parts.utils.distributed import broadcast_tensors @@ -690,7 +691,7 @@ def decode(self, indices, input_len): return dequantized -class ResidualVectorQuantizer(NeuralModule): +class ResidualVectorQuantizer(VectorQuantizerBase): """ Residual vector quantization (RVQ) algorithm as described in https://arxiv.org/pdf/2107.03312.pdf. @@ -732,13 +733,7 @@ def __init__( ] ) - @property - def input_types(self): - return { - "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "input_len": NeuralType(tuple('B'), LengthsType()), - } - + # Override output types, since this quantizer returns commit_loss @property def output_types(self): return { @@ -818,7 +813,7 @@ def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: return dequantized -class GroupResidualVectorQuantizer(NeuralModule): +class GroupResidualVectorQuantizer(VectorQuantizerBase): """Split the input vector into groups and apply RVQ on each group separately. Args: @@ -875,13 +870,7 @@ def codebook_dim_per_group(self): return self.codebook_dim // self.num_groups - @property - def input_types(self): - return { - "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), - "input_len": NeuralType(tuple('B'), LengthsType()), - } - + # Override output types, since this quantizer returns commit_loss @property def output_types(self): return {