Dporpo bugfix (#46)
* Bug fix in Pareto ranking!

Signed-off-by: Shehzeen Hussain <[email protected]>

* ensure best record is never worse than worst record on any metric

Signed-off-by: Shehzeen Hussain <[email protected]>

* reward normalization using z-scores

Signed-off-by: Shehzeen Hussain <[email protected]>

* added comment in notebook for multilingual model

Signed-off-by: Shehzeen Hussain <[email protected]>

---------

Signed-off-by: Shehzeen Hussain <[email protected]>
shehzeen authored Jan 28, 2025
1 parent ba37cb6 commit 3f69088
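
For context on the first two commit bullets: the script ranks candidate samples by Pareto dominance over two metrics, character error rate (lower is better) and predicted context similarity (higher is better). Below is a minimal sketch of a two-metric dominance check under those assumed metric directions; it is illustrative only and not necessarily the exact `is_dominated` implementation in `create_preference_pairs.py`.

# Illustrative sketch (not the repo's exact code): record `a` is dominated by
# record `b` if `b` is at least as good on both metrics and strictly better
# on at least one. Lower cer_gts is better; higher pred_context_similarity is better.
def is_dominated(a, b):
    no_worse = (b["cer_gts"] <= a["cer_gts"]
                and b["pred_context_similarity"] >= a["pred_context_similarity"])
    strictly_better = (b["cer_gts"] < a["cer_gts"]
                       or b["pred_context_similarity"] > a["pred_context_similarity"])
    return no_worse and strictly_better

The second bullet's guarantee then lines up with the pairing constraint visible in the diff below ("Never add pairs in which rejected has better CER than chosen or better context similarity").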
Showing 2 changed files with 27 additions and 12 deletions.
scripts/t5tts/dpo/create_preference_pairs.py (25 additions, 11 deletions)

@@ -3,6 +3,7 @@
 import json
 import copy
 import random
+import math
 
 def main():
     parser = argparse.ArgumentParser()
@@ -48,10 +49,7 @@ def main():
     best_records, worst_records = filter_best_and_worst_records(all_best_records, all_worst_records, args.cer_threshold)
     print("Len filtered best_records: ", len(best_records))
     print("Len filtered worst_records: ", len(worst_records))
-    mean_reward_rejected = sum([record['reward'] for record in worst_records]) / len(worst_records)
-    print("Mean reward rejected: ", mean_reward_rejected)
-    for record in worst_records:
-        record['reward'] = normalize_rejected_reward(record['reward'], mean_reward_rejected)
+    worst_records = normalize_rejected_rewards(worst_records)
     paired_records = [(best_record, worst_record) for best_record, worst_record in zip(best_records, worst_records)]
     random.shuffle(paired_records)
 
@@ -161,13 +159,27 @@ def is_dominated(A, B):
 
     return ranked_items
 
-def normalize_rejected_reward(reward, mean_negative_reward=0.80):
-    # if reward < mean, negative reward
-    # if reward > mean, positive reward
-    # if reward == mean, 0 reward
-    normalized_reward = (reward - mean_negative_reward) / (1.0 - mean_negative_reward)
-    assert normalized_reward < 1.0
-    return max(-1.0, normalized_reward)
+def standard_normal_cdf(z):
+    """
+    Compute the standard normal cumulative distribution function (CDF) for a given z-score.
+    """
+    return 0.5 * (1 + math.erf(z / math.sqrt(2)))
+
+def normalize_rejected_rewards(worst_records):
+    cer_deltas = [record['cer_delta'] for record in worst_records]
+    sim_deltas = [record['sim_delta'] for record in worst_records]
+    cer_mean = sum(cer_deltas) / len(cer_deltas)
+    cer_std = math.sqrt(sum([(d - cer_mean) ** 2 for d in cer_deltas]) / len(cer_deltas))
+    sim_mean = sum(sim_deltas) / len(sim_deltas)
+    sim_std = math.sqrt(sum([(d - sim_mean) ** 2 for d in sim_deltas]) / len(sim_deltas))
+
+    for record in worst_records:
+        cer_z_score = (record['cer_delta'] - cer_mean) / cer_std
+        sim_z_score = (record['sim_delta'] - sim_mean) / sim_std
+        record['reward'] = 1.0 - (standard_normal_cdf(cer_z_score) + standard_normal_cdf(sim_z_score))  # Range -1 to 1
+
+    return worst_records
 
 def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_group=1):
     records = copy.deepcopy(records_orig)
@@ -211,6 +223,8 @@ def create_chosen_rejected_records(records_orig, group_size=6, num_chosen_per_group=1):
         # Never add pairs in which rejected has better CER than chosen or better context similarity
         reward_delta = max(0.001, reward_delta)
         worst_record['reward'] = 1.0 - reward_delta
+        worst_record['cer_delta'] = worst_record['cer_gts'] - best_record['cer_gts']
+        worst_record['sim_delta'] = best_record['pred_context_similarity'] - worst_record['pred_context_similarity']
         best_records.append(best_record)
         worst_records.append(worst_record)
 
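
To see what the new normalization does, here is a self-contained sketch on invented delta values (the helper mirrors the functions added above; the numbers are made up). Each rejected record's reward becomes 1 - (Φ(z_cer) + Φ(z_sim)), where Φ is the standard normal CDF, so rewards fall strictly between -1 and 1: close to +1 when the rejected sample is barely worse than its chosen pair, near 0 when it is about average, and close to -1 when it is far worse on both metrics.

import math

def standard_normal_cdf(z):
    # Phi(z) = 0.5 * (1 + erf(z / sqrt(2)))
    return 0.5 * (1 + math.erf(z / math.sqrt(2)))

# Toy rejected records; cer_delta / sim_delta values are invented for illustration.
worst_records = [
    {"cer_delta": 0.01, "sim_delta": 0.02},   # barely worse than its chosen pair
    {"cer_delta": 0.10, "sim_delta": 0.05},   # moderately worse
    {"cer_delta": 0.50, "sim_delta": 0.30},   # much worse on both metrics
]

cer_deltas = [r["cer_delta"] for r in worst_records]
sim_deltas = [r["sim_delta"] for r in worst_records]
cer_mean = sum(cer_deltas) / len(cer_deltas)
cer_std = math.sqrt(sum((d - cer_mean) ** 2 for d in cer_deltas) / len(cer_deltas))
sim_mean = sum(sim_deltas) / len(sim_deltas)
sim_std = math.sqrt(sum((d - sim_mean) ** 2 for d in sim_deltas) / len(sim_deltas))

for r in worst_records:
    cer_z = (r["cer_delta"] - cer_mean) / cer_std
    sim_z = (r["sim_delta"] - sim_mean) / sim_std
    # Reward in (-1, 1): near +1 when the rejected sample is barely worse,
    # near -1 when it is far worse than average on both metrics.
    r["reward"] = 1.0 - (standard_normal_cdf(cer_z) + standard_normal_cdf(sim_z))
    print(r["reward"])

Note that, like the committed code, this assumes the deltas are not all identical; a population standard deviation of zero would divide by zero.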
t5tts_inference.ipynb (2 additions, 1 deletion)

@@ -188,7 +188,8 @@
 "            feature_dir=audio_base_dir,\n",
 "            text=entry['text'],\n",
 "            speaker=None,\n",
-"            speaker_index=0\n",
+"            speaker_index=0,\n",
+"            tokenizer_names=[\"english_phoneme\"], # Change this for multilingual: \"english_phoneme\", \"spanish_phoneme\", \"english_chartokenizer\", \"german_chartokenizer\".. \n",
 "        )\n",
 "        data_samples.append(dataset_sample)\n",
 "    \n",
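
The tokenizer names listed in the new notebook comment suggest a per-language choice. The mapping below is purely hypothetical (neither the helper nor the language codes come from this commit); it only illustrates how one might pick the tokenizer_names argument when switching the multilingual model between languages.

# Hypothetical mapping, not part of this repo: language code -> tokenizer_names
# value, using the names listed in the notebook comment above.
TOKENIZERS_BY_LANGUAGE = {
    "en": ["english_phoneme"],       # alternatively ["english_chartokenizer"]
    "es": ["spanish_phoneme"],
    "de": ["german_chartokenizer"],
}

def tokenizer_names_for(language):
    # Fall back to English phonemes for unlisted languages (an assumption).
    return TOKENIZERS_BY_LANGUAGE.get(language, ["english_phoneme"])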
