Skip to content

Commit

Permalink
Updated test
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Nov 16, 2023
1 parent b483fb4 commit 6d56bb9
Showing 1 changed file with 43 additions and 61 deletions.
104 changes: 43 additions & 61 deletions tests/collections/tts/modules/test_audio_codec_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,99 +245,81 @@ def test_fsq_eval(self, num_levels: list):

@pytest.mark.unit
def test_fsq_output(self):
"""Simple test to make sure the output of FSQ is correct
for a single setup.
"""Simple test to make sure the output of FSQ is correct for a single setup.
To re-generate test vectors:
```
num_examples, max_len = 5, 8
inputs = torch.randn([num_examples, fsq.codebook_dim, max_len])
input_len = torch.tensor([max_len] * num_examples, dtype=torch.int32)
dequantized, indices = fsq(inputs=inputs, input_len=input_len)
```
"""
num_levels = [2, 3]
num_levels = [3, 4]
fsq = FiniteScalarQuantizer(num_levels=num_levels)

# # To generate inputs & outputs for testing
# max_len = 8
# inputs = torch.randn([self.num_examples, fsq.codebook_dim, max_len])
# input_len = torch.tensor([max_len] * self.num_examples, dtype=torch.int32)
# dequantized, indices = fsq(inputs=inputs, input_len=input_len)
# print(inputs)
# print(input_len)
# print(dequantized)
# print(indices)

# inputs
inputs = torch.tensor(
[
[
[0.6572, 1.3574, 2.1646, 0.9457, -0.3489, 0.6732, -0.7148, -0.5143],
[0.5123, 0.8552, 1.7814, 1.9938, -1.1909, -0.9991, -3.7932, -0.4438],
[0.1483, -0.3855, -0.3715, -0.5913, -0.2212, -0.4226, -0.4864, -1.6069],
[-0.5519, -0.5307, -0.5995, -1.9675, -0.4439, 0.3938, -0.5636, -0.3655],
],
[
[0.2357, 0.8324, 0.8932, -0.0596, 0.6130, -0.0299, 0.3824, 1.6278],
[-0.3781, 0.1864, -0.2190, 1.2199, -1.1398, -0.8443, -0.7865, 0.1470],
[0.5184, 1.4028, 0.1553, -0.2324, 1.0363, -0.4981, -0.1203, -1.0335],
[-0.1567, -0.2274, 0.0424, -0.0819, -0.2122, -2.1851, -1.5035, -1.2237],
],
[
[0.9786, 1.2170, 0.2229, -0.6481, -0.0348, 0.0552, 0.3956, -1.0916],
[1.2982, 0.5188, 0.3546, -1.5305, -1.0674, 1.1292, 0.7662, -0.1397],
[0.9497, 0.8510, -1.2021, 0.3299, -0.2388, 0.8445, 2.2129, -2.3383],
[1.5331, 0.0399, -0.7676, -0.4715, -0.5713, 0.8761, -0.9755, -0.7479],
],
[
[-1.8530, -0.3099, 1.7705, 0.2201, 0.3348, -0.2126, 0.4756, -0.9759],
[-0.6812, 0.8368, -0.5181, -0.6713, -0.0681, -1.5496, 0.9230, 0.3448],
[1.7243, -1.2146, -0.1969, 1.9261, 0.1109, 0.4028, 0.1240, -0.0994],
[-0.3304, 2.1239, 0.1004, -1.4060, 1.1463, -0.0557, -0.5856, -1.2441],
],
[
[-1.1940, -1.3429, 0.5648, 1.5043, -1.0501, -1.0594, -1.0261, -0.2600],
[-1.3521, -0.4740, -1.6166, -0.8975, 0.6101, 0.2225, 1.0959, -0.0723],
[2.3743, -0.1421, -0.4548, 0.6320, -0.2640, -0.3967, -2.5694, 0.0493],
[0.3409, 0.2366, -0.0309, -0.7652, 0.3484, -0.8419, 0.9079, -0.9929],
],
]
)

input_len = torch.tensor([8, 8, 8, 8, 8], dtype=torch.int32)

# expected output
dequantized_expected = torch.tensor(
[
[
[-1.4756, 0.0630, -1.9273, 0.6048, 0.0432, -0.4550, -0.9183, 1.4493],
[-0.8326, 0.4620, 1.8287, 0.2323, -1.5944, -0.8721, 0.0126, -2.1843],
[0.0000, 0.0000, 0.0000, -1.0000, 0.0000, 0.0000, 0.0000, -1.0000],
[-0.5000, -0.5000, -0.5000, -1.0000, -0.5000, 0.0000, -0.5000, -0.5000],
],
[
[0.0913, -0.4664, 0.2600, 0.7711, 0.7383, 0.8726, 0.1065, 0.5777],
[-0.9217, 1.9214, -0.4060, 1.5786, 0.5549, 1.6364, 0.2880, -0.7962],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.0000, -1.0000, -1.0000],
],
[
[-0.5967, -1.0998, -0.2475, 1.2475, -2.1949, 0.0607, 0.5634, 1.0397],
[0.2047, 0.3775, -0.1769, 0.7248, -1.6236, 0.5641, 0.9344, 0.2959],
[1.0000, 1.0000, -1.0000, 0.0000, 0.0000, 1.0000, 1.0000, -1.0000],
[0.5000, 0.0000, -0.5000, -0.5000, -0.5000, 0.5000, -0.5000, -0.5000],
],
[
[0.2875, 0.6632, -0.6974, -0.0710, 0.1296, 0.0872, -1.0767, 1.3350],
[0.7032, 0.6264, -0.3976, -0.0257, -0.5352, -1.1119, -1.0472, 0.1626],
[1.0000, -1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.0000, -1.0000, 0.5000, 0.0000, -0.5000, -1.0000],
],
[
[0.2453, -0.9992, 0.3687, 1.5681, -0.3206, 0.2046, -0.4617, 0.7606],
[-1.8594, 0.2092, 0.1827, -1.1598, 0.0664, 0.9267, 1.5458, 1.5875],
[1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, -1.0000, 0.0000],
[0.0000, 0.0000, 0.0000, -0.5000, 0.0000, -0.5000, 0.5000, -0.5000],
],
]
)

input_len = torch.tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], dtype=torch.int32)

# expected output
dequantized_expected = torch.tensor(
[
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, -1.0, -1.0, -1.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, -1.0, -1.0, 1.0, 1.0, 0.0]],
[[-1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-1.0, 1.0, 0.0, -1.0, 0.0, -1.0, 1.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-1.0, 0.0, -1.0, -1.0, 1.0, 0.0, 1.0, 0.0]],
[[0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-1.0, 0.0, 1.0, 0.0, -1.0, -1.0, 0.0, -1.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, -1.0]],
[[0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, -1.0, 1.0, 1.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0, 0.0, -1.0, -1.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-1.0, 0.0, 0.0, -1.0, 0.0, 1.0, 1.0, 1.0]],
]
)

indices_expected = torch.tensor(
[
[
[3, 5, 5, 5, 1, 1, 1, 3],
[3, 3, 3, 5, 1, 1, 1, 3],
[5, 3, 3, 1, 1, 5, 5, 3],
[0, 5, 3, 1, 3, 1, 5, 3],
[1, 3, 1, 1, 5, 3, 5, 3],
[1, 3, 4, 3, 1, 1, 3, 1],
[1, 5, 3, 5, 5, 5, 3, 1],
[3, 3, 3, 5, 0, 5, 5, 3],
[5, 5, 3, 3, 3, 1, 1, 3],
[1, 3, 3, 1, 3, 5, 5, 5],
[4, 4, 4, 0, 4, 7, 4, 3],
[7, 8, 7, 7, 8, 1, 1, 0],
[11, 8, 3, 4, 4, 11, 5, 3],
[8, 9, 7, 2, 10, 7, 4, 1],
[8, 7, 7, 5, 7, 4, 9, 4],
]
],
dtype=torch.int32,
Expand Down

0 comments on commit 6d56bb9

Please sign in to comment.