[fix] Make gather_for_metrics usage more strict #315

Merged · 3 commits merged into main · Feb 20, 2023

Conversation

@maxreciprocate (Collaborator) commented Feb 17, 2023

This PR makes accelerate.gather_for_metrics happen for each batch. Additionally, elements of each batch are now padded to that batch's max length instead of the global seq_length, and they are collected as lists to avoid one big allocation (possibly larger than RAM), since they don't need to be tensors afterwards.

https://wandb.ai/sorry/trlx/reports/Make-gather_for_metrics-more-strict-315--VmlldzozNTkyMTUy
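As a rough sketch of the approach described above (the loop, model, and tokenizer names are illustrative, not the actual trlx code): each batch is padded only to its own max length, gathered per batch with accelerate's gather_for_metrics, and accumulated into a plain Python list so no dataset-sized tensor is ever allocated.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

def evaluate(model, tokenizer, eval_dataloader):
    """Illustrative eval loop: gather per batch and keep results as lists.

    Assumes eval_dataloader was prepared via accelerator.prepare so that
    gather_for_metrics can drop the samples duplicated to fill the last batch.
    """
    all_samples = []
    for prompts in eval_dataloader:
        samples = model.generate(**prompts)  # [batch, <= max_new_tokens]
        # Pad only to this batch's longest sequence across processes,
        # not to a global seq_length, so ranks agree on shape without over-allocating.
        samples = accelerator.pad_across_processes(
            samples, dim=1, pad_index=tokenizer.pad_token_id
        )
        # Per-batch gather; duplicates from the padded final batch are removed,
        # so the total matches len(eval_prompts) exactly.
        gathered = accelerator.gather_for_metrics(samples)
        # Extend a plain list instead of concatenating into one big tensor.
        all_samples.extend(tokenizer.batch_decode(gathered, skip_special_tokens=True))
    return all_samples
```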

@jon-tow (Collaborator) left a comment


Nice! I left one comment; could you address it when you're free?

@@ -280,6 +281,24 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()


def gather_for_metrics(tensor, expected_number, batch_size, length):
@jon-tow (Collaborator)


This seems to work in our case because the tensor has expected_number == dataset_size, but if you loop through the eval dataloader and collect metrics with this, the last batch might contain duplicate entries from exhausted ranks when the world size, dataset size, and batch size aren't aligned. Maybe we should leave a warning comment so that in the future we remember to manually truncate them in such loops? One might expect it to have the same behavior as accelerate's function here.
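To make the concern concrete, here is a minimal sketch (a hypothetical helper, not code from this PR) of gathering per-batch metrics in a dataloader loop and manually truncating away the padded duplicates; it assumes results come back in dataset order so the excess lands at the tail.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

def collect_metric(eval_dataloader, score_fn, dataset_size):
    """Hypothetical loop: gather per-batch scores, then truncate the overflow
    that distributed samplers pad onto the final batch."""
    scores = []
    for batch in eval_dataloader:
        batch_scores = score_fn(batch)               # [local_batch]
        gathered = accelerator.gather(batch_scores)  # [local_batch * world_size]
        scores.extend(gathered.tolist())
    # Without truncation, len(scores) can exceed dataset_size when
    # world size, dataset size, and batch size aren't aligned.
    return scores[:dataset_size]
```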

@maxreciprocate (Author)


What you describe is exactly what this PR amends (or at least intends to 🥲)

Let's use very unaligned hyperparameters: batch_size=12, world_size=3, len(eval_prompts)=109. With these, running on main
accelerate launch --num_processes 3 --config_file configs/accelerate/zero2-bf16.yaml examples/ppo_sentiments.py
gives only 1 eval sample – https://wandb.ai/sorry/trlx/runs/vh3ayv6b
(to reproduce, you can pull main...1-eval-sample or go to main and set those values manually)

But with the new gather_for_metrics it's back to the expected 109 – https://wandb.ai/sorry/trlx/runs/c6agq89q/
Let's also check that the prompts weren't perturbed (since it's ambiguous in this example) by setting:

eval_prompts = list(map(str, range(109))),

https://wandb.ai/sorry/trlx/runs/c6agq89q

(On a similar note: can we add tests that rely on multiprocessing but would still run on CI?)

@jon-tow (Collaborator) commented Feb 19, 2023


Oh yeah, this definitely fixes our specific eval code, but if you were to use this to collect metrics in a dataloader loop (for example, when one big tensor doesn't fit in RAM or wherever), it might still duplicate entries. You can reproduce the behavior with this gist: https://gist.github.com/jon-tow/304efadc9fce470d7b4f7212d5cfcf18 (sorry, I couldn't figure out how to simulate multi-process on Google Colab lol). I could be misunderstanding things, so let me know :)

Re CI with tests on multi-process functions: I'm not sure 😅 We could always write the test and have it around to at least run locally, then skip the test case in CI. I'll get back to you on that!
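One possible shape for such a test, sketched with hypothetical names and the assumption that the CI runner sets the CI=true environment variable (as GitHub Actions does): write the multi-process test normally and mark it to be skipped on CI.

```python
import os

import pytest
import torch

# Hypothetical test; runs locally on multi-GPU machines, skipped on CI runners.
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="multi-process test; run locally")
@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2,
    reason="requires at least 2 GPUs",
)
def test_gather_for_metrics_multiprocess():
    # e.g. spawn workers with torch.multiprocessing, or invoke the example
    # script via `accelerate launch --num_processes 2 ...` in a subprocess.
    ...
```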

@maxreciprocate (Author)


That's a great code snippet, Jon, thanks! I see now, you're correct: accelerate.gather_for_metrics and gather_for_metrics are two slightly different functions with different usages. To avoid introducing additional code, the existing code can be refactored to resolve the issue instead: accelerate.gather_for_metrics now happens for each batch. Additionally, elements of each batch are now padded to that batch's max length instead of the global seq_length, and they are collected as lists to avoid one big allocation (possibly larger than RAM), since they don't need to be tensors afterwards.

109 example: https://wandb.ai/sorry/trlx/runs/t8nc849j
comparison with main: https://wandb.ai/sorry/trlx/reports/Make-gather_for_metrics-more-strict-315--VmlldzozNTkyMTUy

Re CI with tests on multi-process functions: I guess there is no rush for it as of now 😅

@jon-tow (Collaborator)


Oh, that's a very clean refactor! For some reason, though, this breaks seq2seq evaluation with a RuntimeError from non-contiguous tensors in the underlying gather call 🤔 Can you reproduce it on your end? Full traceback and run config here: https://gist.github.com/jon-tow/8154845bd05cea3946e35b4a7f89a88c

I think once this is cleared up we'll be good 🤞
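For reference, this class of error usually comes from passing a non-contiguous view (e.g. a sliced or transposed tensor) to a collective op. A minimal sketch of the common workaround, calling .contiguous() before gathering, assuming that is the failure mode here (this is not necessarily the exact fix applied in the PR):

```python
import torch
import torch.distributed as dist

def gather_all(tensor: torch.Tensor) -> torch.Tensor:
    """Gather a tensor from every rank after forcing a contiguous layout.

    Views produced by slicing (e.g. samples[:, prompt_len:]) can be
    non-contiguous, which some backends reject in all_gather with a RuntimeError.
    """
    tensor = tensor.contiguous()
    buffers = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, tensor)
    return torch.cat(buffers, dim=0)
```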


@maxreciprocate changed the title from "[fix] Make gather_for_metrics more strict" to "[fix] Make gather_for_metrics usage more strict" on Feb 20, 2023
@jon-tow (Collaborator) left a comment


Awesome! The gather-hanging issue I brought up at the meeting seems related to the num_rollouts config, as you mentioned, so ignore :)

@jon-tow merged commit 3396bf1 into main on Feb 20, 2023
@Jiaxin-Wen mentioned this pull request on Feb 23, 2023