Skip to content

Commit

Permalink
[Codec] Finite scalar quantizer (NVIDIA#7886)
Browse files (browse the repository at this point in the history)
* Finite scalar quantizer

Signed-off-by: Ante Jukić <[email protected]>

* Updated test

Signed-off-by: Ante Jukić <[email protected]>

---------

Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Nov 17, 2023
1 parent 87e8b75 commit 1fd6431
Show file tree
Hide file tree
Showing 4 changed files with 533 additions and 5 deletions.
19 changes: 18 additions & 1 deletion nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

if "vector_quantizer" in cfg:
self.vector_quantizer = instantiate(cfg.vector_quantizer)

vq_output_types = list(self.vector_quantizer.output_types.keys())

if len(vq_output_types) == 3 and vq_output_types[-1] == 'commit_loss':
self.vector_quantizer_has_commit_loss = True
logging.info('Vector quantizer supports commit loss.')
else:
self.vector_quantizer_has_commit_loss = False
logging.info('Vector quantizer does not support commit loss.')

else:
logging.warning('Vector quantizer will not be used.')
self.vector_quantizer = None
Expand Down Expand Up @@ -124,6 +134,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
else:
self.commit_loss_scale = 0.0

if self.commit_loss_scale > 0 and not self.vector_quantizer_has_commit_loss:
raise ValueError('Commit loss is enabled but the quantizer does not support it.')

# Log setup
self.log_config = cfg.get("log_config", None)

Expand Down Expand Up @@ -353,7 +366,11 @@ def _process_batch(self, batch):
encoded = self.encoder_noise(encoded)

if self.vector_quantizer:
encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
if self.vector_quantizer_has_commit_loss:
encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
else:
encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
commit_loss = 0.0
else:
commit_loss = 0.0

Expand Down
Loading

0 comments on commit 1fd6431

Please sign in to comment.