Add T5 model #145
Conversation
- I see you're adding a bunch of new tasks in your PR (which is great!), but they should probably be separated out into other PRs if possible
- Do you have a wandb run you can share?
- I would suggest not freezing anything at first (on a very small model with a single GPU) to make sure the algo is right; see the sketch below
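A minimal sketch of that sanity check, assuming trlx's example PPO config and its `num_layers_unfrozen` field (the model name, config path, and the "-1 = nothing frozen" convention are assumptions for illustration, not part of this PR):

```python
from trlx.data.configs import TRLConfig

# Start from the stock PPO config and override it for a quick single-GPU debug run.
config = TRLConfig.load_yaml("configs/ppo_config.yml")  # path assumed from the repo layout
config.model.model_path = "lvwerra/gpt2-imdb"           # a very small model for debugging
config.model.num_layers_unfrozen = -1                   # assumed convention: train all layers, freeze nothing
config.train.batch_size = 8
```

If PPO behaves on this setup, layer freezing can be reintroduced afterwards.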
examples/ds_config_trlx_neoj.json
@@ -0,0 +1,22 @@
{
You'll probably want to put this under configs when finished
examples/ds_config_trlx_t5.json
@@ -0,0 +1,22 @@
{
Same here (move under configs)
examples/summarize_dataset.py
@@ -0,0 +1,110 @@
import torch
Eventually we'll want to put this dataset onto huggingface
(Possibly) Relevant: https://huggingface.co/datasets/openai/summarize_from_feedback
Actually, I used this dataset before it was public on HF, but that was for the RLHF blog post, not for this PR.
@Dahoas @LouisCastricato, this is an example FlanT5 run on the CNN/DailyMail dataset, but some of the other charts look quite weird. Please check when you have time. https://wandb.ai/pvduy/trlx/runs/8q3skf8p
Looks awesome! I've left some feedback to be addressed. Very excited about this 👍
Thank you for your great comments, I will follow up and fix them.
Fixed PPO for T5 (https://wandb.ai/pvduy/trlx/runs/1n31fb6a). The fix for GPT-J is still running on the OpenAI summarization dataset to verify. Please review this @reciprocated @LouisCastricato @Dahoas
examples/trlx_t5_summ_daily_cnn.py
meteor = evaluate.load("meteor")
...
if __name__ == "__main__":
Since this is an example, we should probably have lots of comments; something like the sketch below.
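For instance, a heavily commented reward function built around the `meteor` metric loaded above could look like this (a sketch only; the function name, signature, and reference handling are illustrative, not the PR's actual code):

```python
import evaluate

# Load the METEOR metric once at module import so it isn't re-created on every
# reward call; `evaluate.load` fetches the metric script the first time it runs.
meteor = evaluate.load("meteor")

def reward_fn(samples, references):
    # `samples` are generated summaries and `references` the gold summaries,
    # both plain strings. Compute METEOR per pair so the trainer receives one
    # scalar reward per generation.
    return [
        meteor.compute(predictions=[pred], references=[ref])["meteor"]
        for pred, ref in zip(samples, references)
    ]
```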
@@ -40,6 +40,7 @@ def __init__(
    # If the policy has no frozen head (e.g. the seq2seq/T5 path added here),
    # build a separate reference model and keep it on the accelerator device.
    if not hasattr(self.trainer.model, "frozen_head"):
        self.ref_model = self.trainer.get_arch(self.trainer.config)
        self.ref_model.to(self.trainer.accelerator.device)
Have we verified this works? I recall accelerate freezing up when I started putting multiple models on GPU (though this could've just been the sentiment pipeline we were using for the sentiments task).
@Dahoas yes that works.
e.g. https://wandb.ai/pvduy/trlx/runs/2wgpt4im
I do agree that with larger models we should distribute multiple models across multiple GPUs, but for this I think we should keep the reference model on GPU rather than CPU; CPU forward passes are super slow.
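For reference, a rough sketch of the trade-off being discussed, reusing the names from the snippet above (the CPU-offload line and the freezing calls are illustrative, not part of this PR):

```python
# Keeping the separate reference model on the accelerator avoids a host<->device
# copy on every rollout, at the cost of holding a second model in GPU memory.
ref_model = trainer.get_arch(trainer.config)   # same architecture as the policy
ref_model.to(trainer.accelerator.device)       # fast: stays on GPU
# ref_model.to("cpu")                          # saves GPU memory, but every forward pass is much slower
ref_model.eval()                               # the reference model is never trained
for p in ref_model.parameters():
    p.requires_grad_(False)
```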
You are right @jon-tow, I am removing it from this PR. We can consider merging this PR today; some bugs in the current main branch should be fixed by it. cc @LouisCastricato
Leaving some tiny final comments and change requests 🙏
rs = rewards[ix]
if len(rs) == 0:
    rs = torch.tensor([0.0])
rs[-1] = scores[ix]
Should we penalize empty responses? Also, do you know how it's possible to get those, except for when max_new_tokens == 0? 🤔
Yes, this is a rare edge case, but I got it a few times when I ran PPO sentiments. @jon-tow you also faced this, right?
Yeah, Jon has said that he also experienced it. I wonder, if the cause of it is unknown, whether it may be a symptom of some other bug elsewhere.
@reciprocated the only case I can think of that'd lead to an empty response is when `len(query)` is larger than the `generate` method's `min_length` arg, which defaults to 10, and the model so happens to output the `eos_token` on its first sample. (Note that with causal models the `min_length` constraint includes the length of the context (query), meaning it won't actually have an effect on the generations if the min condition is already met by the context size.)
In such cases, I'm okay with penalizing empty responses as they're uninformative - so long as this is not a bug lol
I do not think it is a serious issue. We can just throw an error if this `min_length` thing comes up. I've never seen this in practice when I set `min_length` correctly. (Perhaps we should add an extra parameter called `min_new_length`...? We should upstream it to HF transformers, though.)
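As a side note, more recent versions of HF transformers expose a `min_new_tokens` generation argument that counts only newly generated tokens; a rough sketch of how it sidesteps the empty-response case described above (model and prompt are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# This prompt is already longer than the default min_length=10, so min_length
# alone would let the model emit eos immediately and return an empty response.
inputs = tokenizer("A fairly long prompt that already exceeds ten tokens on its own.", return_tensors="pt")

output = model.generate(
    **inputs,
    max_new_tokens=50,
    min_new_tokens=1,                     # counts generated tokens only, unlike min_length
    pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning for GPT-2
)
response = output[0, inputs["input_ids"].shape[1]:]  # at least one non-eos token is guaranteed
```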
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leaving some notes for myself to address in the future.
all_tokens, attention_mask, position_ids = self.trainer.get_model_inputs(
    query_tensors.to(response_tensors.device), response_tensors
)
This removed the need for:
def get_model_inputs(
We should remove it, if unused, before it becomes stale.
logprobs = logprobs.cpu()
ref_logprobs = ref_logprobs.cpu()
query_tensors = query_tensors.cpu()
response_tensors = response_tensors.cpu()
Remove these lines; these vars are already put on CPU on the lines right before the if-statement.
CNN/DailyMail: https://wandb.ai/pvduy/trlx/runs/lx4iq23e
PPO fixed sentiments (GPT2): https://wandb.ai/pvduy/trlx/runs/2cyo46k4