[feat] Gather experience samples #305
Conversation
re the side-note: Maybe we need to call
Oh, I forgot to mention that scores from the sentiment pipeline remain the same, as the following passes:

scores = torch.tensor(
    self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
    ),
    dtype=torch.float,
).to(device)
scores2 = torch.tensor(
    self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
    ),
    dtype=torch.float,
).to(device)
assert torch.all(scores == scores2)

However, the number of calls to reward_fn still influences the rng state, and thus the run (see the side-note).
Awesome! Glad to see you push us to the CW side hehe. Left one comment, if you could address it before merging 🙏
scores = torch.empty(len(samples), device=device)
torch.distributed.scatter(scores, all_scores)

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples)
Looks like we're decoding tokens twice now, which doesn't seem to slow anything down judging from the system plots (if anything, the difference is probably from the gather overhead). Does decode mutate the inputs in any way, such that the line below will be different from a single call to decode?
trlx/trlx/trainer/accelerate_ppo_trainer.py, lines 325 to 329 in 2a45c08:

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples)
# Pad the sample outputs
outputs = self.tokenizer(str_outputs).input_ids
outputs = list(map(torch.LongTensor, outputs))
To avoid repetition I could scatter samples as well, but I don't think it's worth it, since the second decode (which would be second only on the main rank) is needed just to strip stop_sequences, so it's basically just an allocation. Runtime in make_experience is still dominated by generate and reward_fn.
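For illustration only, a hedged sketch (not trlx's actual decode) of what that second pass on the main rank boils down to, truncating each decoded output at the first stop sequence:

from typing import List

def strip_stop_sequences(text: str, stop_sequences: List[str]) -> str:
    # Cut the decoded string at the earliest occurrence of any stop sequence.
    for stop in stop_sequences:
        idx = text.find(stop)
        if idx != -1:
            text = text[:idx]
    return text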
Yeah, that's fine! (I actually meant to approve this last night but forgot to 😅)
The sentiments pipeline is in eval-mode, so at least it's not any stochasticity from dropout in that distilbert RM. Also, the returns/values stats sort of explode for sentiments on main - it's interesting that the repeat call smooths things out.
This is great! Happy that we're finally merging something like this.
Thanks Max!
@jon-tow Even if the pipeline is in eval-mode, rng state is still accessed somewhere from within it. A simple test:

torch.manual_seed(1000)
print(f'{torch.rand(1)=}')
print(f'{torch.rand(1)=}')
print(f'{torch.rand(1)=}')
>>> torch.rand(1)=tensor([0.3189])
>>> torch.rand(1)=tensor([0.6136])
>>> torch.rand(1)=tensor([0.4418])

torch.manual_seed(1000)
print(f'{torch.rand(1)=}')
print(f'{torch.rand(1)=}')
reward_fn(['1'])
print(f'{torch.rand(1)=}')
>>> torch.rand(1)=tensor([0.3189])
>>> torch.rand(1)=tensor([0.6136])
>>> torch.rand(1)=tensor([0.2724])

Since none of the ranks except the main now access the rng from within the pipeline, the runs differ slightly from the reference. Here's the reward_fn used:

def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]

sentiment_fn = pipeline(
    "sentiment-analysis",
    top_k=2,
    truncation=True,
    batch_size=256,
    device=device,
)

def reward_fn(samples: List[str], **kwargs) -> List[float]:
    sentiments = list(map(get_positive_score, sentiment_fn(samples)))
    return sentiments

https://wandb.ai/sorry/trlx/reports/Gather-experience-samples-305--VmlldzozNTM1OTUz

Here are two additional reports: a single-process run, and a run with a changed reward_fn:

def reward_fn(samples, **kwargs):
    return [len(s) for s in samples]
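If the goal is only to keep the training rng stream identical regardless of how often the pipeline runs, one option is to sandbox the reward call. A minimal sketch (an assumption, not part of this PR) using torch.random.fork_rng:

import torch

def rng_isolated_reward_fn(reward_fn, samples, device_index=0):
    # fork_rng snapshots the CPU rng state (and the rng of the listed CUDA
    # devices) and restores it on exit, so whatever the sentiment pipeline
    # touches internally is invisible to the rest of training.
    devices = [device_index] if torch.cuda.is_available() else []
    with torch.random.fork_rng(devices=devices):
        return reward_fn(samples)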
@reciprocated yeah that's weird. If you call the underlying distilbert with some arbitrary inputs, the RNG doesn't change: https://colab.research.google.com/drive/1FCmeWEJGl5GAhikeUXXR5VrdHOt4S6WN?usp=sharing
We should ping the Hugging Face folks.
This PR lets the PPO trainer gather all experience samples on the main rank and make a single joint reward_fn call per rollout.
This enables hosting a single reward model on the same machine as the main rank, whereas previously every process had to have access to the reward model. It also enables deliberate micro-batching for the reward model, unlike the case where each process queries the reward model (for example, one deployed on a Triton server) with its own small chunk_size batch of samples, which usually bottlenecks the whole training.
The main goal here is to drop the current dependency on the Triton server and to enable simple, self-contained 7+1 (RM) or 15+1 (RM) setups (we can finally move to CW).
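For illustration, a rough, hedged sketch of the gather → single reward_fn call → scatter flow described above (function and variable names are illustrative, not the PR's code; it assumes an initialized process group and equal per-rank sample counts):

import torch
import torch.distributed as dist

def gathered_rewards(reward_fn, str_samples, device):
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # 1) Collect every rank's decoded samples on the main rank.
    all_samples = [None] * world_size if rank == 0 else None
    dist.gather_object(str_samples, all_samples, dst=0)

    # 2) Only the main rank queries the reward model, once, over the joint
    #    batch, so a single RM instance and its micro-batch size can be chosen
    #    freely instead of each process sending its own small chunk.
    if rank == 0:
        flat = [s for per_rank in all_samples for s in per_rank]
        all_scores = torch.tensor(reward_fn(samples=flat), dtype=torch.float, device=device)
        chunks = [c.contiguous() for c in all_scores.split(len(str_samples))]
    else:
        chunks = None

    # 3) Hand each rank back its own slice of the scores.
    scores = torch.empty(len(str_samples), device=device)
    dist.scatter(scores, chunks, src=0)
    return scores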
https://wandb.ai/sorry/trlx/reports/Gather-experience-samples-305--VmlldzozNTQ0OTkz
https://wandb.ai/sorry/trlx/reports/Gather-experience-samples-305---VmlldzozNTMxMzc3
Side-note: every reference run remains the same except for sentiments. After some debugging I've noticed that even doing multiple redundant passes of the sentiment pipeline over the same data apparently changes the rng state, or otherwise slightly influences the run.
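A minimal sketch of this effect (an illustration, not the snippet from the original description): any extra consumption of the global rng stream, such as a redundant pipeline pass, shifts every subsequently sampled token.

import torch

logits = torch.ones(8)

torch.manual_seed(0)
reference = torch.multinomial(logits, num_samples=4, replacement=True)

torch.manual_seed(0)
_ = torch.rand(1)  # stand-in for the rng access observed inside the sentiment pipeline
perturbed = torch.multinomial(logits, num_samples=4, replacement=True)

print(reference, perturbed)  # the two draws generally differ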