From 1fd6431a590430a31ae228355edeed1197ec006b Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:35:43 -0800 Subject: [PATCH] [Codec] Finite scalar quantizer (#7886) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Finite scalar quantizer Signed-off-by: Ante Jukić * Updated test Signed-off-by: Ante Jukić --------- Signed-off-by: Ante Jukić --- nemo/collections/tts/models/audio_codec.py | 19 +- .../tts/modules/audio_codec_modules.py | 353 +++++++++++++++++- .../tts/modules/encodec_modules.py | 2 +- .../tts/modules/test_audio_codec_modules.py | 164 ++++++++ 4 files changed, 533 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 9b6675db5979..069a14c0eab3 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -77,6 +77,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if "vector_quantizer" in cfg: self.vector_quantizer = instantiate(cfg.vector_quantizer) + + vq_output_types = list(self.vector_quantizer.output_types.keys()) + + if len(vq_output_types) == 3 and vq_output_types[-1] == 'commit_loss': + self.vector_quantizer_has_commit_loss = True + logging.info('Vector quantizer supports commit loss.') + else: + self.vector_quantizer_has_commit_loss = False + logging.info('Vector quantizer does not support commit loss.') + else: logging.warning('Vector quantizer will not be used.') self.vector_quantizer = None @@ -124,6 +134,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): else: self.commit_loss_scale = 0.0 + if self.commit_loss_scale > 0 and not self.vector_quantizer_has_commit_loss: + raise ValueError('Commit loss is enabled but the quantizer does not support it.') + # Log setup self.log_config = cfg.get("log_config", None) @@ -353,7 +366,11 @@ def _process_batch(self, batch): encoded = self.encoder_noise(encoded) if self.vector_quantizer: - encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + if self.vector_quantizer_has_commit_loss: + encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + else: + encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len) + commit_loss = 0.0 else: commit_loss = 0.0 diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index 0e1459002cdc..38c1b8147643 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
+import torch
 import torch.nn as nn
+from einops import rearrange
 
 from nemo.collections.asr.parts.utils.activations import Snake
 from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
 from nemo.core.classes.common import typecheck
 from nemo.core.classes.module import NeuralModule
-from nemo.core.neural_types.elements import LengthsType, VoidType
+from nemo.core.neural_types.elements import EncodedRepresentation, Index, LengthsType, VoidType
 from nemo.core.neural_types.neural_type import NeuralType
+from nemo.utils import logging
 
 
 def get_padding(kernel_size: int, dilation: int = 1) -> int:
@@ -64,7 +67,7 @@ def forward(self, x):
 
 class Conv1dNorm(NeuralModule):
     def __init__(
-        self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Optional[int] = None
+        self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: Optional[int] = None,
     ):
         super().__init__()
         if not padding:
@@ -181,3 +184,347 @@ def remove_weight_norm(self):
     @typecheck()
     def forward(self, inputs):
         return self.conv(inputs)
+
+
+class FiniteScalarQuantizer(NeuralModule):
+    """This quantizer is based on the Finite Scalar Quantization (FSQ) method.
+    It quantizes each element of the input vector independently into a number of levels.
+
+    Args:
+        num_levels: number of levels for each dimension/element of the input vector
+        eps: small regularization constant for scaling
+
+    References:
+        Mentzer et al., Finite Scalar Quantization: VQ-VAE Made Simple (https://arxiv.org/abs/2309.15505v1)
+    """
+
+    def __init__(self, num_levels: List[int], eps: float = 1e-3):
+        super().__init__()
+
+        # index base per dimension of the input vector
+        # this is used to convert between per-dimension indices and a codebook token index
+        dim_base_index = torch.cumprod(torch.tensor([1] + num_levels[:-1]), dim=0, dtype=torch.int32)
+        dim_base_index = rearrange(dim_base_index, 'D -> 1 D 1')
+        self.register_buffer('dim_base_index', dim_base_index)
+
+        # Register the number of levels for each dimension
+        num_levels = torch.tensor(num_levels, dtype=torch.int32)
+        num_levels = rearrange(num_levels, 'D -> 1 D 1')
+        self.register_buffer('num_levels', num_levels)
+
+        # Regularization
+        self.eps = eps
+
+        logging.debug('Initializing %s with', self.__class__.__name__)
+        logging.debug('\tdim: %s', self.dim)
+        logging.debug('\tnum_levels: %s', self.num_levels)
+        logging.debug('\tcodebook_size: %s', self.codebook_size)
+        logging.debug('\teps: %s', self.eps)
+
+    @property
+    def codebook_size(self):
+        """Returns the size of the corresponding codebook."""
+        return self.num_levels.prod().item()
+
+    @property
+    def dim(self):
+        """Returns the dimension of the input vector."""
+        return self.num_levels.numel()
+
+    @property
+    def codebook_dim(self):
+        """Returns the dimension of the input vector.
+        Kept for compatibility with the original RVQ implementation.
+        """
+        return self.dim
+
+    @property
+    def codes(self):
+        """Returns the codebook entries.
+
+        Note that the codebook entries are implicitly defined by the number of levels.
+        """
+        indices = torch.arange(self.codebook_size)
+        # [D, B, T]
+        indices = rearrange(indices, 'B -> 1 B 1')
+        # [B, D, T]
+        codes = self.decode(indices=indices, input_len=None)
+        # Remove the time dimension
+        codes = codes.squeeze(-1)
+        return codes
+
+    @property
+    def codebook(self):
+        """Returns the codebook entries.
+        See self.codes for more details.
+        """
+        return self.codes
+
+    @staticmethod
+    def round(inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
+        """Round the input tensor to the nearest integer
+        and use a straight-through estimator for the gradient.
+        """
+        inputs_rounded = torch.round(inputs)
+        return inputs + (inputs_rounded - inputs).detach()
+
+    def compress(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
+        """Apply compression to the input, limiting the value range so that rounding
+        yields the configured number of levels for each dimension.
+        """
+        output_scale = (self.num_levels - 1) / 2
+        # scale down a bit to avoid rounding issues
+        output_scale = output_scale * (1 - self.eps)
+        # offset for even number of levels
+        output_offset = torch.where(self.num_levels % 2 == 0, 0.5, 0)
+        # shift for even number of levels
+        input_shift = (output_offset / output_scale).tan()
+        # compressed output
+        output = output_scale * (inputs + input_shift).tanh() - output_offset
+        return output
+
+    @typecheck(
+        input_types={
+            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
+            "input_len": NeuralType(tuple('B'), LengthsType()),
+        },
+        output_types={"codes": NeuralType(('B', 'D', 'T'), Index())},
+    )
+    def inputs_to_codes(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
+        # apply compression
+        compressed = self.compress(inputs=inputs, input_len=input_len)
+        # apply rounding to nearest integer
+        codes = self.round(inputs=compressed, input_len=input_len)
+        # normalize to [-1, 1]
+        scale = self.num_levels // 2
+        codes = codes / scale
+        return codes
+
+    def codes_to_nonnegative(self, codes: torch.Tensor) -> torch.Tensor:
+        """Convert values centered around zero to nonnegative values.
+        """
+        scale = offset = self.num_levels // 2
+        return scale * codes + offset
+
+    def nonnegative_to_codes(self, codes_nonnegative: torch.Tensor) -> torch.Tensor:
+        """Convert nonnegative values to values centered around zero.
+        """
+        scale = offset = self.num_levels // 2
+        return (codes_nonnegative - offset) / scale
+
+    def codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor:
+        """Converts a code vector to a single index.
+ """ + if codes.size(1) != self.dim: + raise RuntimeError( + f'Input code dimension {codes.size(1)} not matching the expected dimension {self.dim}, input codes shape {codes.shape}' + ) + # convert code vectors to nonnegative values + indices = self.codes_to_nonnegative(codes) + # convert one nonnegative index per dimension to a single index per code vector + indices = torch.sum(indices * self.dim_base_index, dim=1) + return indices.to(torch.int32) + + # API of the RVQ + @property + def input_types(self): + return { + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self): + return { + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "indices": NeuralType(('D', 'B', 'T'), Index()), + } + + @typecheck() + def forward( + self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if inputs.size(1) != self.dim: + raise RuntimeError( + f'Input dimension {inputs.size(1)} not matching the expected dimension {self.dim}, inputs shape {inputs.shape}' + ) + + dequantized = self.inputs_to_codes(inputs=inputs, input_len=input_len) + indices = self.codes_to_indices(codes=dequantized) + + if input_len is not None: + # apply masking + dequantized = mask_sequence_tensor(dequantized, input_len) + indices = mask_sequence_tensor(indices, input_len) + + # only 1 codebook, but return in [D, B, T] format to match RVQ API + indices = indices.unsqueeze(0) + return dequantized, indices + + @typecheck( + input_types={ + "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, + ) + def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: + """Convert a continuous code vector to a single index. + """ + _, indices = self(inputs=inputs, input_len=input_len) + return indices + + @typecheck( + input_types={ + "indices": NeuralType(('D', 'B', 'T'), Index()), + "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + ) + def decode(self, indices: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: + """Convert a single index to a continuous code vector. + """ + if indices.size(0) > 1: + # codebook dimension used for compatibility with RVQ + raise ValueError( + f'Expected a single codebook, got {indices.size(0)} codebooks for indices with shape {indices.shape}.' + ) + + indices = rearrange(indices, 'D B T -> B D T') + # convert a single index to nonnegative index per-dimension + codes_nonnegative = (indices // self.dim_base_index) % self.num_levels + # convert nonnegative codes to codes (centered around zero) + dequantized = self.nonnegative_to_codes(codes_nonnegative) + + if input_len is not None: + # apply masking + dequantized = mask_sequence_tensor(dequantized, input_len) + return dequantized + + +class GroupFiniteScalarQuantizer(NeuralModule): + """Split the input vector into groups and apply FSQ on each group separately. + This class is for convenience. Since FSQ is applied on each group separately, + groups can be defined arbitrarily by splitting the input vector. However, this + class makes it easy to construct several groups with the same quantization num_levels. 
+
+    Args:
+        num_groups: number of groups to split the input into; each group is quantized with its own FSQ
+        num_levels_per_group: number of levels for each dimension/element within a group
+        **kwargs: additional parameters of FiniteScalarQuantizer
+
+    References:
+        Yang et al., HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765).
+    """
+
+    def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs):
+        super().__init__()
+
+        self.num_groups = num_groups
+        self.codebook_dim_per_group = len(num_levels_per_group)
+
+        # Initialize FSQ for each group
+        self.fsqs = torch.nn.ModuleList(
+            [FiniteScalarQuantizer(num_levels=num_levels_per_group, **kwargs) for _ in range(self.num_groups)]
+        )
+
+        logging.debug('Initialized %s with', self.__class__.__name__)
+        logging.debug('\tnum_groups: %d', self.num_groups)
+        logging.debug('\tcodebook_dim: %d', self.codebook_dim)
+        logging.debug('\tnum_levels_per_group: %s', num_levels_per_group)
+        logging.debug('\tcodebook_dim_per_group: %d', self.codebook_dim_per_group)
+
+    @property
+    def codebook_dim(self):
+        """Input vector dimension.
+        """
+        return self.codebook_dim_per_group * self.num_groups
+
+    @property
+    def codebook_size_per_group(self):
+        """Returns the size of the implicit codebook for each group."""
+        return self.fsqs[0].codebook_size
+
+    @property
+    def codebook_size(self):
+        """Returns the size of the implicit codebook."""
+        return self.codebook_size_per_group ** self.num_groups
+
+    @property
+    def input_types(self):
+        return {
+            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
+            "input_len": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        return {
+            "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
+            "indices": NeuralType(('D', 'B', 'T'), Index()),
+        }
+
+    @typecheck()
+    def forward(self, inputs, input_len):
+        """Quantize each group separately, then concatenate the results.
+        """
+        inputs_grouped = inputs.chunk(self.num_groups, dim=1)
+
+        dequantized, indices = [], []
+
+        for in_group, fsq_group in zip(inputs_grouped, self.fsqs):
+            dequantized_group, indices_group = fsq_group(inputs=in_group, input_len=input_len)
+            dequantized.append(dequantized_group)
+            indices.append(indices_group)
+
+        # concatenate along the feature dimension
+        dequantized = torch.cat(dequantized, dim=1)
+
+        # concatenate along the codebook dimension
+        indices = torch.cat(indices, dim=0)
+
+        return dequantized, indices
+
+    @typecheck(
+        input_types={
+            "inputs": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),
+            "input_len": NeuralType(tuple('B'), LengthsType()),
+        },
+        output_types={"indices": NeuralType(('D', 'B', 'T'), Index())},
+    )
+    def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor:
+        """Input is split into groups, each group is encoded separately, then the results are concatenated.
+ """ + inputs_grouped = inputs.chunk(self.num_groups, dim=1) + indices = [] + + for in_group, fsq_group in zip(inputs_grouped, self.fsqs): + indices_group = fsq_group.encode(inputs=in_group, input_len=input_len) + indices.append(indices_group) + + # concatenate along the codebook dimension + indices = torch.cat(indices, dim=0) + + return indices + + @typecheck( + input_types={ + "indices": NeuralType(('D', 'B', 'T'), Index()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, + output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + ) + def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: + """Input indices are split into groups, each group is decoded separately, then the results are concatenated. + """ + indices_grouped = indices.chunk(self.num_groups, dim=0) + dequantized = [] + + for indices_group, fsq_group in zip(indices_grouped, self.fsqs): + dequantized_group = fsq_group.decode(indices=indices_group, input_len=input_len) + dequantized.append(dequantized_group) + + # concatenate along the feature dimension + dequantized = torch.cat(dequantized, dim=1) + + return dequantized diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index c2d5d9ffddc8..1719dcc15179 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -265,7 +265,7 @@ def __init__( out_channels = in_channels // 2 kernel_size = 2 * up_sample_rate up_sample_conv = ConvTranspose1dNorm( - in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate, ) in_channels = out_channels self.up_sample_conv_layers.append(up_sample_conv) diff --git a/tests/collections/tts/modules/test_audio_codec_modules.py b/tests/collections/tts/modules/test_audio_codec_modules.py index 5f9abb1831c5..6d4f88c0f417 100644 --- a/tests/collections/tts/modules/test_audio_codec_modules.py +++ b/tests/collections/tts/modules/test_audio_codec_modules.py @@ -19,6 +19,8 @@ CodecActivation, Conv1dNorm, ConvTranspose1dNorm, + FiniteScalarQuantizer, + GroupFiniteScalarQuantizer, get_down_sample_padding, ) from nemo.collections.tts.modules.encodec_modules import GroupResidualVectorQuantizer, ResidualVectorQuantizer @@ -205,3 +207,165 @@ def test_snake(self): snake = CodecActivation('snake', channels=self.in_channels) out = snake(x=inputs) assert out.shape == (self.batch_size, self.in_channels, self.max_len) + + +class TestFiniteScalarQuantizer: + def setup_class(self): + """Setup common members + """ + self.batch_size = 2 + self.max_len = 20 + self.num_examples = 10 + + @pytest.mark.unit + @pytest.mark.parametrize('num_levels', [[2, 3], [8, 5, 5]]) + def test_fsq_eval(self, num_levels: list): + """Simple test to confirm that the FSQ module can be instantiated and run, + and that forward produces the same result as encode-decode. 
+ """ + fsq = FiniteScalarQuantizer(num_levels=num_levels) + + for i in range(self.num_examples): + inputs = torch.randn([self.batch_size, fsq.codebook_dim, self.max_len]) + input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32) + + # apply forward + dequantized_fw, indices_fw = fsq(inputs=inputs, input_len=input_len) + + assert dequantized_fw.max() <= 1.0, f'example {i}: dequantized_fw.max() is {dequantized_fw.max()}' + assert dequantized_fw.min() >= -1.0, f'example {i}: dequantized_fw.min() is {dequantized_fw.min()}' + + # encode-decode + indices_enc = fsq.encode(inputs=inputs, input_len=input_len) + dequantized_dec = fsq.decode(indices=indices_enc, input_len=input_len) + + # make sure the results are the same + torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch') + torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch') + + @pytest.mark.unit + def test_fsq_output(self): + """Simple test to make sure the output of FSQ is correct for a single setup. + + To re-generate test vectors: + ``` + num_examples, max_len = 5, 8 + inputs = torch.randn([num_examples, fsq.codebook_dim, max_len]) + input_len = torch.tensor([max_len] * num_examples, dtype=torch.int32) + dequantized, indices = fsq(inputs=inputs, input_len=input_len) + ``` + """ + num_levels = [3, 4] + fsq = FiniteScalarQuantizer(num_levels=num_levels) + + # inputs + inputs = torch.tensor( + [ + [ + [0.1483, -0.3855, -0.3715, -0.5913, -0.2212, -0.4226, -0.4864, -1.6069], + [-0.5519, -0.5307, -0.5995, -1.9675, -0.4439, 0.3938, -0.5636, -0.3655], + ], + [ + [0.5184, 1.4028, 0.1553, -0.2324, 1.0363, -0.4981, -0.1203, -1.0335], + [-0.1567, -0.2274, 0.0424, -0.0819, -0.2122, -2.1851, -1.5035, -1.2237], + ], + [ + [0.9497, 0.8510, -1.2021, 0.3299, -0.2388, 0.8445, 2.2129, -2.3383], + [1.5331, 0.0399, -0.7676, -0.4715, -0.5713, 0.8761, -0.9755, -0.7479], + ], + [ + [1.7243, -1.2146, -0.1969, 1.9261, 0.1109, 0.4028, 0.1240, -0.0994], + [-0.3304, 2.1239, 0.1004, -1.4060, 1.1463, -0.0557, -0.5856, -1.2441], + ], + [ + [2.3743, -0.1421, -0.4548, 0.6320, -0.2640, -0.3967, -2.5694, 0.0493], + [0.3409, 0.2366, -0.0309, -0.7652, 0.3484, -0.8419, 0.9079, -0.9929], + ], + ] + ) + + input_len = torch.tensor([8, 8, 8, 8, 8], dtype=torch.int32) + + # expected output + dequantized_expected = torch.tensor( + [ + [ + [0.0000, 0.0000, 0.0000, -1.0000, 0.0000, 0.0000, 0.0000, -1.0000], + [-0.5000, -0.5000, -0.5000, -1.0000, -0.5000, 0.0000, -0.5000, -0.5000], + ], + [ + [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.0000, -1.0000, -1.0000], + ], + [ + [1.0000, 1.0000, -1.0000, 0.0000, 0.0000, 1.0000, 1.0000, -1.0000], + [0.5000, 0.0000, -0.5000, -0.5000, -0.5000, 0.5000, -0.5000, -0.5000], + ], + [ + [1.0000, -1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.5000, 0.0000, -1.0000, 0.5000, 0.0000, -0.5000, -1.0000], + ], + [ + [1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000, 0.0000], + [0.0000, 0.0000, 0.0000, -0.5000, 0.0000, -0.5000, 0.5000, -0.5000], + ], + ] + ) + + indices_expected = torch.tensor( + [ + [ + [4, 4, 4, 0, 4, 7, 4, 3], + [7, 8, 7, 7, 8, 1, 1, 0], + [11, 8, 3, 4, 4, 11, 5, 3], + [8, 9, 7, 2, 10, 7, 4, 1], + [8, 7, 7, 5, 7, 4, 9, 4], + ] + ], + dtype=torch.int32, + ) + + # test + dequantized, indices = fsq(inputs=inputs, input_len=input_len) + torch.testing.assert_close(dequantized, dequantized_expected, msg=f'dequantized mismatch') + 
torch.testing.assert_close(indices, indices_expected, msg=f'indices mismatch') + + @pytest.mark.unit + @pytest.mark.parametrize('num_groups', [1, 2, 4]) + @pytest.mark.parametrize('num_levels_per_group', [[2, 3], [8, 5, 5]]) + def test_group_fsq_eval(self, num_groups: int, num_levels_per_group: int): + """Simple test to confirm that the group FSQ module can be instantiated and run, + and that forward produces the same result as encode-decode. + """ + # Test inference with group FSQ + # instantiate + gfsq = GroupFiniteScalarQuantizer(num_groups=num_groups, num_levels_per_group=num_levels_per_group) + + for i in range(self.num_examples): + inputs = torch.randn([self.batch_size, gfsq.codebook_dim, self.max_len]) + input_len = torch.tensor([self.max_len] * self.batch_size, dtype=torch.int32) + + # apply forward + dequantized_fw, indices_fw = gfsq(inputs=inputs, input_len=input_len) + + # encode-decode + indices_enc = gfsq.encode(inputs=inputs, input_len=input_len) + dequantized_dec = gfsq.decode(indices=indices_enc, input_len=input_len) + + # make sure the results are the same + torch.testing.assert_close(indices_enc, indices_fw, msg=f'example {i}: indices mismatch') + torch.testing.assert_close(dequantized_dec, dequantized_fw, msg=f'example {i}: dequantized mismatch') + + # apply individual FSQs and make sure the results are the same + inputs_grouped = inputs.chunk(num_groups, dim=1) + dequantized_fw_grouped = dequantized_fw.chunk(num_groups, dim=1) + indices_fw_grouped = indices_fw.chunk(num_groups, dim=0) + + for g in range(num_groups): + dequantized, indices = gfsq.fsqs[g](inputs=inputs_grouped[g], input_len=input_len) + torch.testing.assert_close( + dequantized, dequantized_fw_grouped[g], msg=f'example {i}: dequantized mismatch for group {g}' + ) + torch.testing.assert_close( + indices, indices_fw_grouped[g], msg=f'example {i}: indices mismatch for group {g}' + )
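
Usage sketch (illustration only, not part of the patch): a minimal example of how the quantizers introduced above can be exercised, assuming this branch of NeMo is installed. Tensor shapes and keyword arguments follow the type signatures in the patch; the particular `num_levels` values, batch size, and sequence length are arbitrary choices for the example.

```python
# Minimal sketch, assuming the modules from this patch are importable.
import torch

from nemo.collections.tts.modules.audio_codec_modules import (
    FiniteScalarQuantizer,
    GroupFiniteScalarQuantizer,
)

# FSQ with levels [8, 5, 5]: input dimension 3, implicit codebook size 8 * 5 * 5 = 200
fsq = FiniteScalarQuantizer(num_levels=[8, 5, 5])
inputs = torch.randn(2, fsq.dim, 10)  # [B, D, T]
input_len = torch.tensor([10, 10], dtype=torch.int32)  # [B]

# forward returns dequantized values in [-1, 1] and indices in RVQ-style [D, B, T] layout
dequantized, indices = fsq(inputs=inputs, input_len=input_len)
print(dequantized.shape, indices.shape)  # torch.Size([2, 3, 10]) torch.Size([1, 2, 10])

# encode/decode reproduce the forward pass
assert torch.equal(fsq.encode(inputs=inputs, input_len=input_len), indices)
torch.testing.assert_close(fsq.decode(indices=indices, input_len=input_len), dequantized)

# Group FSQ: two groups, each with its own FSQ over [8, 5, 5], so the input dimension is 2 * 3 = 6
gfsq = GroupFiniteScalarQuantizer(num_groups=2, num_levels_per_group=[8, 5, 5])
inputs = torch.randn(2, gfsq.codebook_dim, 10)
dequantized, indices = gfsq(inputs=inputs, input_len=input_len)
print(indices.shape)  # torch.Size([2, 2, 10]), one codebook index per group
```

Internally, `codes_to_indices` converts each dimension of a code vector to a nonnegative level and accumulates the levels with `dim_base_index` (a mixed-radix encoding), which is why the implicit codebook size is the product of `num_levels`.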