Skip to content

Commit 1cab64b

Browse files
viettmabsayakpaul
andauthored
Update train_diffusion_dpo.py (#6754)
* Update train_diffusion_dpo.py Address #6702 * Update train_diffusion_dpo_sdxl.py * Empty-Commit --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 8d7dc85 commit 1cab64b

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/research_projects/diffusion_dpo/train_diffusion_dpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def collate_fn(examples):
860860
# Final loss.
861861
scale_term = -0.5 * args.beta_dpo
862862
inside_term = scale_term * (model_diff - ref_diff)
863-
loss = -1 * F.logsigmoid(inside_term.mean())
863+
loss = -1 * F.logsigmoid(inside_term).mean()
864864

865865
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
866866
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)

examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
975975
# Final loss.
976976
scale_term = -0.5 * args.beta_dpo
977977
inside_term = scale_term * (model_diff - ref_diff)
978-
loss = -1 * F.logsigmoid(inside_term.mean())
978+
loss = -1 * F.logsigmoid(inside_term).mean()
979979

980980
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
981981
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)

0 commit comments

Comments
 (0)