Fix RLHF Llama reward modeling backward issue #612
I hit this issue while enabling DDP fine-tuning of the reward model for PPO.
I would expect the fix to simply change this operation so that it is not done in place. Or is that not possible?
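For context on the class of failure being discussed, here is a minimal sketch of how an in-place operation can break the backward pass in PyTorch, and how the out-of-place form avoids it. It is illustrative only: the tensors and operations are made up for the example and are not the operation this PR touches.

```python
import torch

x = torch.ones(3, requires_grad=True)

# In-place variant: exp() saves its output for the backward pass,
# so overwriting that output in place invalidates the saved value.
y = x.exp()
y.mul_(2.0)
try:
    y.sum().backward()
    inplace_failed = False
except RuntimeError:
    # autograd detects that a tensor needed for the gradient was modified
    inplace_failed = True

# Out-of-place variant: a new tensor is created, the saved output is untouched.
z = x.exp()
z = z * 2.0
z.sum().backward()

print(inplace_failed, x.grad)
```

Whether the out-of-place form is acceptable in a given model depends on memory and throughput constraints, which is the trade-off raised later in this thread.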
It looks good to me, but I'll wait for my Gaudi2 instance to be fixed before merging so I can check that training and inference throughput are not impacted.
Any update on your side? Have you got your Gaudi2 card? @regisss
Yes, I'll check this PR today or tomorrow.
Thanks, glad to hear that.
Hmm, I see a 3% throughput regression on Llama2-70b generation with this fix.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Force-pushed from 5208139 to 2c50ae9.
Done. Do you know the reason for the regression? Does it break something like static shapes?
Thanks, I'm going to try it. |
After taking a closer look at it and reading this thread, I think the best option here is to define self.norm_factor as a non-tensor float:
- a variable defined with register_buffer will be moved to the target device when calling model.to(device), which is not the case if it is defined as a regular tensor
- we have persistent=False, which means that this variable will not be part of the state dict anyway (same as defining it as a float)
The current implementation with torch.tensor leads to a tiny speed regression because the tensor will always be on CPU, even after calling model.to(device). We could easily live with that, but it seems that just switching from a float tensor to a regular float gives a small speedup for the exact same behavior so let's do it.
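A minimal sketch of the three options compared above, to make the device and state-dict behavior concrete. The `ScaleHolder` class, its attribute names, and the `1 / sqrt(head_dim)` value are hypothetical and are not taken from the PR. The `meta` device is used only as a stand-in for an accelerator so that the snippet runs on a CPU-only machine.

```python
import math

import torch
from torch import nn


class ScaleHolder(nn.Module):
    """Hypothetical module; only the three ways of storing the scale matter."""

    def __init__(self, head_dim: int = 128):
        super().__init__()
        scale = 1.0 / math.sqrt(head_dim)  # assumed value, for illustration
        # Option 1: non-persistent buffer, follows model.to(...) but is not saved.
        self.register_buffer("scale_buffer", torch.tensor(scale), persistent=False)
        # Option 2: plain tensor attribute, ignored by model.to(...).
        self.scale_tensor = torch.tensor(scale)
        # Option 3: plain Python float, has no device at all.
        self.scale_float = scale


# "meta" stands in for the target device in this sketch.
model = ScaleHolder().to("meta")

buffer_device = model.scale_buffer.device.type
tensor_device = model.scale_tensor.device.type
state_keys = list(model.state_dict().keys())
print(buffer_device, tensor_device, state_keys)
```

The buffer follows the module to the new device while the plain tensor attribute stays where it was created, and none of the three appears in the state dict. A Python float avoids the device question entirely, which fits the suggestion above.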
@sywangyi Can you just check that your script still works with the change I'm suggesting?
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Works on my side.