[TTS] Added a callback for logging initial data (#7384)
Signed-off-by: Ante Jukić <[email protected]>
anteju committed Sep 8, 2023
1 parent 0a05a23 commit 2f2e47d
Showing 1 changed file with 95 additions and 16 deletions.
111 changes: 95 additions & 16 deletions nemo/collections/tts/parts/utils/callbacks.py
@@ -110,14 +110,16 @@ def create_id(filepath: Path) -> str:
class ArtifactGenerator(ABC):
@abstractmethod
def generate_artifacts(
self, model: LightningModule, batch_dict: Dict
self, model: LightningModule, batch_dict: Dict, initial_log: bool = False
) -> Tuple[List[AudioArtifact], List[ImageArtifact]]:
"""
Create artifacts for the input model and test batch.
Args:
model: Model instance being trained to use for inference.
batch_dict: Test batch to generate artifacts for.
initial_log: Flag denoting whether this is the initial log; can
be used to save ground-truth data only once.
Returns:
List of audio and image artifacts to log.
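To make the updated interface concrete, here is a minimal sketch (not part of this commit) of a generator that uses the new `initial_log` flag to emit ground-truth audio exactly once, before training starts. The class name `GroundTruthAudioGenerator`, the batch keys, and the `model.sample_rate` attribute are assumptions for illustration; `ArtifactGenerator`, `AudioArtifact`, `ImageArtifact`, and `create_id` are names defined in this file.

from typing import Dict, List, Tuple

from pytorch_lightning import LightningModule

# ArtifactGenerator, AudioArtifact, ImageArtifact and create_id are defined in
# nemo/collections/tts/parts/utils/callbacks.py (the file changed by this commit).


class GroundTruthAudioGenerator(ArtifactGenerator):
    """Hypothetical generator: logs the ground-truth audio once, on the initial log."""

    def generate_artifacts(
        self, model: LightningModule, batch_dict: Dict, initial_log: bool = False
    ) -> Tuple[List[AudioArtifact], List[ImageArtifact]]:
        if not initial_log:
            # Ground truth does not change during training, so log it only once.
            return [], []

        audio_ids = [create_id(p) for p in batch_dict.get("audio_filepaths")]
        audio = batch_dict.get("audio")            # assumed batch key, shape (B, T)
        audio_len = batch_dict.get("audio_lens")   # assumed batch key, shape (B,)

        audio_artifacts = []
        for i, audio_id in enumerate(audio_ids):
            audio_gt = audio[i, : audio_len[i]].cpu().numpy()
            audio_artifacts.append(
                AudioArtifact(
                    id=f"audio_gt_{audio_id}",
                    data=audio_gt,
                    filename=f"{audio_id}_gt.wav",
                    sample_rate=model.sample_rate,  # assumes the model exposes sample_rate
                )
            )
        return audio_artifacts, []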
@@ -214,17 +216,48 @@ def _log_image(self, image: ImageArtifact, log_dir: Path, step: int):
wandb_image = (wandb.Image(image_plot, caption=image.id),)
self.wandb_logger.log({image.id: wandb_image})

def _log_artifacts(self, audio_list: list, image_list: list, log_dir: Optional[Path] = None, global_step: int = 0):
"""Log audio and image artifacts.
"""
if log_dir is not None:
log_dir.mkdir(parents=True, exist_ok=True)

for audio in audio_list:
self._log_audio(audio=audio, log_dir=log_dir, step=global_step)

for image in image_list:
self._log_image(image=image, log_dir=log_dir, step=global_step)

def on_fit_start(self, trainer: Trainer, model: LightningModule):
"""Log initial data artifacts.
"""
audio_list = []
image_list = []
for batch_dict in self.data_loader:
for key, value in batch_dict.items():
if isinstance(value, torch.Tensor):
batch_dict[key] = value.to(model.device)

for generator in self.generators:
audio, images = generator.generate_artifacts(model=model, batch_dict=batch_dict, initial_log=True)
audio_list += audio
image_list += images

if len(audio_list) == len(image_list) == 0:
logging.debug('Lists are empty, no initial artifacts to log.')
return

log_dir = self.output_dir / f"initial" if self.output_dir else None

self._log_artifacts(audio_list=audio_list, image_list=image_list, log_dir=log_dir)
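For orientation, a small illustration (the output_dir value is assumed) of where these artifacts end up on disk: on_fit_start writes the initial artifacts under an "initial" subdirectory at global step 0, while on_train_epoch_end below writes under "epoch_<N>".

from pathlib import Path

# Illustration only; output_dir is whatever the callback was configured with.
output_dir = Path("/results/tts_logs")
initial_dir = output_dir / "initial"    # used by on_fit_start (global_step defaults to 0)
epoch_dir = output_dir / "epoch_10"     # used by on_train_epoch_end at epoch 10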

def on_train_epoch_end(self, trainer: Trainer, model: LightningModule):
"""Log artifacts at the end of an epoch.
"""
epoch = 1 + model.current_epoch
if (epoch not in self.log_epochs) and (epoch % self.epoch_frequency != 0):
return

if self.output_dir:
log_dir = self.output_dir / f"epoch_{epoch}"
log_dir.mkdir(parents=True, exist_ok=True)
else:
log_dir = None

audio_list = []
image_list = []
for batch_dict in self.data_loader:
@@ -237,11 +270,13 @@ def on_train_epoch_end(self, trainer: Trainer, model: LightningModule):
audio_list += audio
image_list += images

for audio in audio_list:
self._log_audio(audio=audio, log_dir=log_dir, step=model.global_step)
if len(audio_list) == len(image_list) == 0:
logging.debug('Lists are empty, no artifacts to log at epoch %d.', epoch)
return

for image in image_list:
self._log_image(image=image, log_dir=log_dir, step=model.global_step)
log_dir = self.output_dir / f"epoch_{epoch}" if self.output_dir else None

self._log_artifacts(audio_list=audio_list, image_list=image_list, log_dir=log_dir)
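As a quick illustration of the epoch gate at the top of this method (the log_epochs and epoch_frequency values are assumed), an epoch is logged when it is either listed explicitly or a multiple of the logging frequency:

# Mirrors the guard in on_train_epoch_end; values are assumed for illustration.
log_epochs = [1, 2]
epoch_frequency = 10


def should_log(epoch: int) -> bool:
    # Skip only when the epoch is neither explicitly listed
    # nor a multiple of the logging frequency.
    return not ((epoch not in log_epochs) and (epoch % epoch_frequency != 0))


print([epoch for epoch in range(1, 31) if should_log(epoch)])  # [1, 2, 10, 20, 30]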


class VocoderArtifactGenerator(ArtifactGenerator):
@@ -250,8 +285,11 @@ class VocoderArtifactGenerator(ArtifactGenerator):
"""

def generate_artifacts(
self, model: LightningModule, batch_dict: Dict
self, model: LightningModule, batch_dict: Dict, initial_log: bool = False
) -> Tuple[List[AudioArtifact], List[ImageArtifact]]:
if initial_log:
# Currently, nothing to log before training starts
return [], []

audio_artifacts = []

@@ -297,7 +335,16 @@ def __init__(self, log_audio: bool = True, log_encoding: bool = False, log_dequa
logging.debug('\tlog_encoding: %s', self.log_encoding)
logging.debug('\tlog_dequantized: %s', self.log_dequantized)

def _generate_audio(self, model, audio_ids, audio, audio_len):
def _generate_audio(self, model, audio_ids, audio, audio_len, save_input: bool = False):
"""Generate audio artifacts.
Args:
model: callable model that outputs (audio_pred, audio_pred_len)
audio_ids: list of IDs for the examples in the audio batch
audio: tensor of input audio signals, shape (B, T)
audio_len: tensor of lengths for each example in the batch, shape (B,)
save_input: if True, also save the input audio signals
"""
if not self.log_audio:
return []

@@ -317,9 +364,29 @@ def _generate_audio(self, model, audio_ids, audio, audio_len):
)
audio_artifacts.append(audio_artifact)

if save_input:
# save input audio
for i, audio_id in enumerate(audio_ids):
audio_in_i = audio[i, : audio_len[i]].cpu().numpy()
audio_artifact = AudioArtifact(
id=f"audio_in_{audio_id}",
data=audio_in_i,
filename=f"{audio_id}_audio_in.wav",
sample_rate=model.sample_rate,
)
audio_artifacts.append(audio_artifact)

return audio_artifacts

def _generate_images(self, model, audio_ids, audio, audio_len):
"""Generate image artifacts.
Args:
model: model; must support `model.encode_audio`, `model.quantize` and `model.dequantize`
audio_ids: list of IDs for the examples in the audio batch
audio: tensor of input audio signals, shape (B, T)
audio_len: tensor of lengths for each example in the batch, shape (B,)
"""
image_artifacts = []

if not self.log_encoding and not self.log_dequantized:
@@ -363,16 +430,24 @@ def _generate_images(self, model, audio_ids, audio, audio_len):
return image_artifacts

def generate_artifacts(
self, model: LightningModule, batch_dict: Dict
self, model: LightningModule, batch_dict: Dict, initial_log: bool = False
) -> Tuple[List[AudioArtifact], List[ImageArtifact]]:
"""
Args:
model: model used to process the input and generate artifacts
batch_dict: dictionary obtained from the dataloader
initial_log: if True, also save the input audio (used for the initial log)
"""

audio_filepaths = batch_dict.get("audio_filepaths")
audio_ids = [create_id(p) for p in audio_filepaths]

audio = batch_dict.get("audio")
audio_len = batch_dict.get("audio_lens")

audio_artifacts = self._generate_audio(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len)
audio_artifacts = self._generate_audio(
model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len, save_input=initial_log
)
image_artifacts = self._generate_images(model=model, audio_ids=audio_ids, audio=audio, audio_len=audio_len)

return audio_artifacts, image_artifacts
@@ -518,9 +593,13 @@ def _generate_gta_predictions(self, model: LightningModule, audio_ids: List[str]
return audio_artifacts, image_artifacts

def generate_artifacts(
self, model: LightningModule, batch_dict: Dict
self, model: LightningModule, batch_dict: Dict, initial_log: bool = False
) -> Tuple[List[AudioArtifact], List[ImageArtifact]]:

if initial_log:
# Currently, nothing to log before training starts
return [], []

audio_artifacts = []
image_artifacts = []
audio_filepaths = batch_dict.get("audio_filepaths")