Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Chunked DPO Loss Kernel #378

Merged
merged 6 commits into from
Nov 15, 2024

Conversation

austin362667
Copy link
Contributor

@austin362667 austin362667 commented Nov 13, 2024

Summary

Add support for a fused, torch-compiled, and chunked DPO (Direct Preference Optimization) loss kernel, as requested in #371.
This implementation is largely based on the excellent work done on ORPO (#362) by @shivam15s.

DPO Loss Formulation

In a reference setting:

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x))$$

$$-\log(\sigma((\log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x)) - \log(\pi_{\theta_{\text{ref}}}(y_c|x)) + \log(\pi_{\theta_{\text{ref}}}(y_r|x)))/\beta))$$

Corresponds to:

# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# policy_chosen_logps - ref_chosen_logps - policy_rejected_logps + ref_rejected_logps
logits_diff = (chosen_advantages - rejected_advantages) * beta

# DPO loss
losses = -F.logsigmoid(logits_diff)

Testing Done

dpo_loss_memory
dpo_loss_speed

  • Hardware Type: NVIDIA L40S (48G)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

run_benchmarks,
)

from liger_kernel.alignment.dpo_loss import HF_DPO_Loss, LigerFusedLinearDPOFunction
Copy link
Contributor Author

@austin362667 austin362667 Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I use HF DPO impl here in benchmarking for function reusability purpose? Or write another naive impl in pure torch?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HF DPO should be fine

return grad_input, grad_weight, None, grad_bias, None, None, None


class HF_DPO_Loss:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I move this HF impl to file test_dpo_loss.py?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, since HF impl is only for testing purpose

@lancerts
Copy link
Collaborator

can we modify logits_diff = (chosen_logps - rejected_logps) / beta
to
logits_diff = (chosen_logps - rejected_logps) * beta to align with the convention in paper as well as the trl implementation here

Copy link
Collaborator

@pramodith pramodith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a FYI, I think we should wait until @shivam15s pushes a generic/inheritable class that handles all the chunking and other repetitive logic common to different loss functions, before pushing new loss functions.

@shivam15s
Copy link
Collaborator

Great work @austin362667 ! The additional summing of NLL loss is going to be useful for IRPO loss as well :). I'll be creating a simple base class which adds the boilerplate code (backward/torch compile logic) that you can inherit from, as @pramodith mentioned

@austin362667 austin362667 marked this pull request as ready for review November 14, 2024 07:56
@austin362667
Copy link
Contributor Author

Issue addressed. Thanks @Tcc0403 @lancerts @pramodith @shivam15s and @ByronHsu for review!

@Tcc0403
Copy link
Collaborator

Tcc0403 commented Nov 14, 2024

I think we should make chunked_loss functions nn.Module (like flce and fljsd) for users? same for orpo? cc @shivam15s @ByronHsu

@ByronHsu
Copy link
Collaborator

@Tcc0403 that is the plan!

Signed-off-by: Austin Liu <[email protected]>

Fix benchmark script
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
@ByronHsu ByronHsu merged commit 1aa3d83 into linkedin:main Nov 15, 2024
1 of 3 checks passed
@austin362667 austin362667 mentioned this pull request Nov 15, 2024
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants