Conversation
The Liger loss isn't compatible with precomputing the reference log-probs, right? If so, we could add a warning or an error.

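If the two options do turn out to be incompatible, the suggested error could be a small fail-fast check at trainer setup. The sketch below is only an illustration: `DPOArgs` is a hypothetical stand-in for the trainer config, and the field names `use_liger_loss` and `precompute_ref_log_probs` are assumptions about how the options are spelled.

```python
from dataclasses import dataclass


@dataclass
class DPOArgs:
    """Hypothetical stand-in for the trainer config; field names are assumptions."""

    use_liger_loss: bool = False
    precompute_ref_log_probs: bool = False


def validate_liger_args(args: DPOArgs) -> None:
    """Fail fast when both options are enabled together."""
    if args.use_liger_loss and args.precompute_ref_log_probs:
        raise ValueError(
            "use_liger_loss=True cannot be combined with "
            "precompute_ref_log_probs=True; disable one of them."
        )


# Either option alone passes the check; enabling both raises ValueError.
validate_liger_args(DPOArgs(use_liger_loss=True))
validate_liger_args(DPOArgs(precompute_ref_log_probs=True))
```

Raising an error rather than emitting a warning keeps a run from silently training with a loss path that ignores the precomputed values; a warning would be the softer choice if the combination merely wastes work.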
This PR needs to use _FSDPForwardRedirection or another solution to work with FSDP correctly.

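As I understand it, the idea behind a forward-redirection helper such as the `_FSDPForwardRedirection` mentioned above is to route a non-`forward` method through the wrapper's own call path, so that the wrapper's setup and teardown (parameter gathering and resharding, in the FSDP case) still run around it. The sketch below shows that pattern with toy classes and no PyTorch dependency; `ToyWrapper`, `ToyModel`, and `redirect_through_forward` are hypothetical names, not the actual helper's API.

```python
class ToyWrapper:
    """Hypothetical stand-in for an FSDP wrapper: __call__ runs setup and
    teardown around the wrapped module's forward."""

    def __init__(self, module):
        self.module = module
        self.events = []

    def __call__(self, *args, **kwargs):
        self.events.append("pre-forward")  # FSDP would gather parameters here
        try:
            return self.module.forward(*args, **kwargs)
        finally:
            self.events.append("post-forward")  # FSDP would reshard here


class ToyModel:
    def forward(self, x):
        return x + 1

    def compute_loss(self, x):
        # A non-forward method that also needs the wrapper's setup to have run.
        return 2 * x


def redirect_through_forward(wrapper, module, method, *args, **kwargs):
    """Run `method` inside `wrapper.__call__` by temporarily swapping `module.forward`."""
    original_forward = module.forward

    def patched_forward(*a, **kw):
        # Restore first, so any forward call made inside `method` behaves normally.
        module.forward = original_forward
        return method(*a, **kw)

    module.forward = patched_forward
    try:
        return wrapper(*args, **kwargs)
    finally:
        # Make sure the original forward is back even if `method` raised.
        module.forward = original_forward


model = ToyModel()
wrapped = ToyWrapper(model)
loss = redirect_through_forward(wrapped, model, model.compute_loss, 3)
```

Here `compute_loss` runs between the wrapper's "pre-forward" and "post-forward" steps, which is the property a fused loss would need when it reads model weights outside of `forward`.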
@VProv, at the moment I'm having issues getting the same outputs/metrics with and without Liger in the trainer.
What setup are you using?

Hi, I am working on fixing the output/metrics issue.

@kashif @qgallouedec can you please review the following PR, which fixes the output/metrics issue? Thanks :)

Thanks @hanbyul-kim for the report.

@kashif, just circling back to see whether we can merge this now. We'd like to try it out internally at LinkedIn.
trl/trainer/dpo_trainer.py
Outdated
    import wandb

    def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:

    def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
        """Shift input ids one token to the right, and pad with pad_token_id"""
I don't think this docstring is accurate: it still mentions pad_token_id, which the signature no longer takes.
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co> Co-authored-by: Vaibhav Jindal <32337828+vaibhavjindal@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
What does this PR do?
Adds support for Liger Kernel losses in the DPO trainer.
Needs: linkedin/Liger-Kernel#521
Peft support: #3065