
[feat] Gather experience samples #305

Merged 2 commits into main on Feb 13, 2023
Conversation

@maxreciprocate (Collaborator) commented Feb 10, 2023

This PR lets the PPO trainer gather all experience samples on the main rank and make a single joint reward_fn call per rollout.

This enables hosting a single reward model on the same machine as the main rank, whereas previously every process had to have access to the reward model. It also enables deliberate micro-batching for the reward model, unlike the case where each process queries the reward model (for example, one deployed on a Triton server) with its own small chunk_size worth of samples, which usually bottlenecks the whole training run.

The main goal here is to drop the current dependency on the Triton server and to enable simple, self-contained 7+1 (RM) or 15+1 (RM) setups (we can finally move to CW).
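For context, here is a minimal sketch of the gather-then-score-then-scatter pattern this PR introduces (gathered_reward_scores and its arguments are illustrative names, not the exact trainer code; the real implementation also handles prompts/outputs and uneven per-rank batch sizes):

import torch
import torch.distributed as dist

def gathered_reward_scores(str_samples, reward_fn, device, world_size):
    # Gather every rank's decoded samples onto the main rank.
    rank = dist.get_rank()
    all_str_samples = [None] * world_size if rank == 0 else None
    dist.gather_object(str_samples, all_str_samples, dst=0)

    if rank == 0:
        # One joint reward_fn call over all gathered samples,
        # then split the scores back into per-rank chunks.
        flat = [s for rank_samples in all_str_samples for s in rank_samples]
        all_scores = torch.tensor(reward_fn(samples=flat), dtype=torch.float, device=device)
        score_chunks = list(all_scores.chunk(world_size))
    else:
        score_chunks = None

    # Each rank receives only the scores for its own samples.
    scores = torch.empty(len(str_samples), device=device)
    dist.scatter(scores, score_chunks, src=0)
    return scores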

https://wandb.ai/sorry/trlx/reports/Gather-experience-samples-305--VmlldzozNTQ0OTkz
https://wandb.ai/sorry/trlx/reports/Gather-experience-samples-305---VmlldzozNTMxMzc3

Side note: every reference run remains the same except for sentiments. After some debugging I noticed that even doing multiple redundant passes of the sentiment pipeline over the same data apparently changes the RNG state, or otherwise slightly influences the run, as in doing:

# First pass: the scores actually used for training
scores = torch.tensor(
    self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
    ),
    dtype=torch.float,
).to(device)

# Redundant second pass over the exact same data
scores = torch.tensor(
    self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
    ),
    dtype=torch.float,
).to(device)

# Redundant third pass: the extra calls still consume RNG and subtly change the rest of the run
scores = torch.tensor(
    self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
    ),
    dtype=torch.float,
).to(device)

@cat-state (Collaborator) commented Feb 10, 2023

Re the side note: maybe we need to call .eval() somewhere, if it's using dropout? Or it could be nondeterminism from the kernels.
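A rough way to check both hypotheses, assuming the sentiment_fn pipeline and reward_fn from the example config (a sketch, not code from the PR):

import torch

# Dropout check: the pipeline's underlying model should not be in training mode.
assert not sentiment_fn.model.training

# Kernel-nondeterminism check: prefer deterministic kernels where available
# and compare two passes over identical inputs.
torch.use_deterministic_algorithms(True, warn_only=True)
first = reward_fn(["a test sample"])
second = reward_fn(["a test sample"])
assert first == second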

@maxreciprocate (Collaborator, Author)

Oh, I forgot to mention that the scores from the sentiment pipeline do remain the same; the following assertion passes:

scores = torch.tensor(
    self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
    ),
    dtype=torch.float,
).to(device)

scores2 = torch.tensor(
    self.reward_fn(
        samples=str_samples,
        prompts=str_prompts,
        outputs=str_outputs,
    ),
    dtype=torch.float,
).to(device)

assert torch.all(scores == scores2)

However, the number of calls to reward_fn is divided by num_processes compared to the reference, so I suspect that's the reason behind the difference.

@jon-tow (Collaborator) left a comment:

Awesome! Glad to see you push us to the CW side, hehe. I left one comment; if you could address it before merging 🙏


scores = torch.empty(len(samples), device=device)
torch.distributed.scatter(scores, all_scores)

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples)
@jon-tow (Collaborator) commented Feb 11, 2023

Looks like we're decoding tokens twice now, which doesn't seem to slow anything down according to the system plots (if anything, the overhead is probably from the gather). Does decode mutate the inputs in any way, such that the line below would differ from a single call to decode? It doesn't seem to.

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples)
# Pad the sample outputs
outputs = self.tokenizer(str_outputs).input_ids
outputs = list(map(torch.LongTensor, outputs))

@maxreciprocate (Collaborator, Author) commented Feb 11, 2023

To avoid the repetition I could scatter samples as well, but I don't think it's worth it: the second decode (which is a second call only on the main rank) is needed just to strip stop_sequences, so it's basically just an allocation. Runtime in make_experience is still dominated by generate and reward_fn.
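For reference, the second decode on the main rank effectively amounts to truncating each output at its first stop sequence, roughly as in the sketch below (strip_stop_sequences and the stop_sequences argument are illustrative names, not the trainer's exact API):

def strip_stop_sequences(str_outputs, stop_sequences):
    # Truncate each generated output at the first occurrence of any stop sequence.
    stripped = []
    for output in str_outputs:
        for stop in stop_sequences:
            output = output.split(stop)[0]
        stripped.append(output)
    return stripped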

Collaborator replied:

Yeah, that's fine! (I actually meant to approve this last night but forgot to 😅)

@jon-tow (Collaborator) commented Feb 11, 2023

> Re the side note: maybe we need to call .eval() somewhere, if it's using dropout? Or it could be nondeterminism from the kernels.

The sentiments pipeline is in eval mode, so at least it's not stochasticity from dropout in that distilbert RM (sentiment_fn.model.training == False) 🤔

Also, the returns/values stats sort of explode for sentiments on main; it's interesting that the repeated call smooths things out.

@LouisCastricato (Contributor) commented:

This is great! Happy that we're finally merging something like this.

@jon-tow (Collaborator) left a comment:

Thanks Max!

@maxreciprocate (Collaborator, Author) commented Feb 11, 2023

@jon-tow Even if the pipeline is in eval mode, the RNG state is still accessed somewhere within it; a simple test:

# Reference: three consecutive draws after seeding
torch.manual_seed(1000)
print(f'{torch.rand(1)=}')
print(f'{torch.rand(1)=}')
print(f'{torch.rand(1)=}')
>>> torch.rand(1)=tensor([0.3189])
>>> torch.rand(1)=tensor([0.6136])
>>> torch.rand(1)=tensor([0.4418])

# Same seed, but with a reward_fn call in between: the third draw changes
torch.manual_seed(1000)
print(f'{torch.rand(1)=}')
print(f'{torch.rand(1)=}')
reward_fn(['1'])
print(f'{torch.rand(1)=}')
>>> torch.rand(1)=tensor([0.3189])
>>> torch.rand(1)=tensor([0.6136])
>>> torch.rand(1)=tensor([0.2724])

Since no rank other than the main one now accesses reward_fn, a difference surfaced. reward_fn in this case being:

from typing import List

from transformers import pipeline

def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]

# `device` is assumed to be set elsewhere in the example script
sentiment_fn = pipeline(
    "sentiment-analysis",
    top_k=2,
    truncation=True,
    batch_size=256,
    device=device,
)

def reward_fn(samples: List[str], **kwargs) -> List[float]:
    sentiments = list(map(get_positive_score, sentiment_fn(samples)))
    return sentiments
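If exact parity with the reference were ever needed, one low-effort option (an editor's sketch, not something this PR does) would be to sandbox the pipeline's RNG usage with torch.random.fork_rng, assuming device here is the CUDA device the pipeline runs on:

import torch

def rng_isolated_reward_fn(samples, **kwargs):
    # Restore the global RNG state (CPU and the pipeline's CUDA device) after
    # the pipeline call, so the number of reward_fn calls no longer nudges training.
    with torch.random.fork_rng(devices=[device]):
        return reward_fn(samples, **kwargs)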

https://wandb.ai/sorry/trlx/reports/Gather-experience-samples-305--VmlldzozNTM1OTUz

Here are two additional reports: a single-process run, and a run with reward_fn changed from the pipeline to something deterministic, like:

def reward_fn(samples, **kwargs):
    return [len(s) for s in samples]

@jon-tow (Collaborator) commented Feb 12, 2023

@reciprocated yeah, that's weird. If you call the underlying distilbert with some arbitrary inputs, the RNG doesn't change: https://colab.research.google.com/drive/1FCmeWEJGl5GAhikeUXXR5VrdHOt4S6WN?usp=sharing
pipeline is super opaque, so I can't tell where it's happening 🤔 This PR should be fine, especially since the t5 summarize reference is good.
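One way to narrow down where inside pipeline the RNG gets consumed, as a debugging sketch (consumes_rng is an illustrative helper; sentiment_fn and device are the objects from the example above):

import torch

def consumes_rng(fn, *args, **kwargs):
    # True if calling fn advances torch's global CPU RNG state.
    before = torch.random.get_rng_state()
    fn(*args, **kwargs)
    return not torch.equal(before, torch.random.get_rng_state())

inputs = sentiment_fn.tokenizer(["1"], return_tensors="pt").to(device)
print(consumes_rng(sentiment_fn.tokenizer, ["1"]))  # tokenizer alone
print(consumes_rng(sentiment_fn.model, **inputs))   # bare forward pass
print(consumes_rng(sentiment_fn, ["1"]))            # full pipeline call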

@LouisCastricato (Contributor) commented:

We should ping the Hugging Face folks.

@maxreciprocate merged commit 724b618 into main on Feb 13, 2023
@maxreciprocate deleted the gather-exp-samples branch on February 13, 2023 at 11:49