Skip to content

Commit

Permalink
Handling invalid audio generations (for DPO) (#43)
Browse files Browse the repository at this point in the history
* test_step(): handle invalid audio

When the model generates an output that is very short (fewer than 2 ASR frames), the ASR and SSIM calculations will error out. We detect the error and invalidate the entire batch, setting WER/CER to 100% and SSIM to 0.0. The transcription is set to "<INVALID>".

Note that the metrics are still written out to the `.metrics` files; they need to be ignored by any subsequent statistics calculations.

* DPO: changes to preference pair creation

1. Skip groups that have any invalid records.
2. Allow the number of records to exactly match the number of audio files (vs requiring it to be strictly smaller).
3. Add `tqdm` to indicate progress during long loops.

* Comment

* Fix merge issues and a bug

Refining the handling of invalid entries in DPO preference selection.

* Fix merge issues
  • Loading branch information
rfejgin authored Jan 30, 2025
1 parent 3f69088 commit d70b903
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 25 deletions.
63 changes: 40 additions & 23 deletions nemo/collections/tts/models/t5tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,7 @@ def test_step(self, batch, batch_idx):
)
predicted_audio_paths = []
audio_durations = []
batch_invalid = False
for idx in range(predicted_audio.size(0)):
predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]
Expand All @@ -1079,33 +1080,49 @@ def test_step(self, batch, batch_idx):
predicted_codes_torch = predicted_codes_torch[:, :predicted_codes_lens[idx]]
torch.save(predicted_codes_torch, os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'))
predicted_audio_paths.append(audio_path)

with torch.no_grad():
if self.cfg.get("pref_set_language", "en") == "en":
pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths))[0]
pred_transcripts = [ self.process_text(transcript) for transcript in pred_transcripts ]
else:
pred_transcripts = [self.transcribe_with_whisper(audio_path, self.cfg.pref_set_language) for audio_path in predicted_audio_paths]
pred_transcripts = [self.process_text(transcript) for transcript in pred_transcripts]
pred_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(predicted_audio_paths)
gt_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(batch['audio_filepaths'])

if not batch_invalid:
with torch.no_grad():
try:
if self.cfg.get("pref_set_language", "en") == "en":
pred_transcripts = self.eval_asr_model.transcribe(predicted_audio_paths, batch_size=len(predicted_audio_paths))[0]
pred_transcripts = [ self.process_text(transcript) for transcript in pred_transcripts ]
else:
pred_transcripts = [self.transcribe_with_whisper(audio_path, self.cfg.pref_set_language) for audio_path in predicted_audio_paths]
pred_transcripts = [self.process_text(transcript) for transcript in pred_transcripts]
except Exception as e:
assert (predicted_audio_lens[idx] < 1000).any(), f"Expected short audio file to be the only cause of ASR errors, but got error with lengths {predicted_audio_lens}"
logging.warning(f"Exception during ASR transcription: {e}")
logging.warning(f"Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0")
batch_invalid = True
continue # don't break since we want to continue building audio durations list
pred_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(predicted_audio_paths)
gt_speaker_embeddings = self.get_speaker_embeddings_from_filepaths(batch['audio_filepaths'])

for idx in range(predicted_audio.size(0)):
audio_path = predicted_audio_paths[idx]
item_idx = batch_idx * test_dl_batch_size + idx
pred_transcript = pred_transcripts[idx]
gt_transcript = self.process_text(batch['raw_texts'][idx])
if not batch_invalid:
audio_path = predicted_audio_paths[idx]
item_idx = batch_idx * test_dl_batch_size + idx
pred_transcript = pred_transcripts[idx]
gt_transcript = self.process_text(batch['raw_texts'][idx])

cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True)
wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False)
cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True)
wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False)

spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy()
spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy()

spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / (
np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt)
)
else:
# Create an entry indicating invalid metrics
cer_gt = 1.0
wer_gt = 1.0
spk_similarity = 0.0
pred_transcript = "<INVALID>"
gt_transcript = self.process_text(batch['raw_texts'][idx])

spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy()
spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy()

spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / (
np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt)
)

item_metrics = {
'cer_gt': float(cer_gt),
'wer_gt': float(wer_gt),
Expand Down
14 changes: 12 additions & 2 deletions scripts/t5tts/dpo/create_preference_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import random
import math
from tqdm import tqdm

def main():
parser = argparse.ArgumentParser()
Expand All @@ -18,7 +19,7 @@ def main():
audio_files, codec_files, metric_files = find_audio_files(args.generated_audio_dir)
assert len(records) <= len(audio_files), "Mismatch between number of records and number of generated audio files {} vs {}".format(len(records), len(audio_files))

for idx, record in enumerate(records):
for idx, record in tqdm(enumerate(records)):
if idx % 100 == 0:
print("At idx: ", idx, len(records))
record['audio_filepath'] = audio_files[idx]
Expand Down Expand Up @@ -187,6 +188,7 @@ def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_gr
num_groups = len(records) // group_size
best_records = []
worst_records = []
num_skipped = 0

if num_chosen_per_group == 1:
chosen_group_indices = [0]
Expand All @@ -203,9 +205,16 @@ def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_gr
group = records[gsi:gei]

cer_sim_indices = []
skip_group = False
for sidx, record in enumerate(group):
if record['pred_transcript'] == "<INVALID>":
print(f"Skipping group starting at index {gsi} due to invalid entries.")
num_skipped += len(group)
skip_group = True
break
cer_sim_indices.append((record['cer_gts'], record['pred_context_similarity'], sidx))

if skip_group:
continue
cer_sim_indices_orig = copy.deepcopy(cer_sim_indices)
cer_sim_indices = pareto_rank(cer_sim_indices)

Expand All @@ -228,6 +237,7 @@ def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_gr
best_records.append(best_record)
worst_records.append(worst_record)

print(f"Skipped {num_skipped} records due to invalid entries.")
return best_records, worst_records

def filter_best_and_worst_records(best_records, worst_records, cer_threshold=0.02):
Expand Down

0 comments on commit d70b903

Please sign in to comment.