[fix] Make gather_for_metrics usage more strict #315
Conversation
Nice! I left one comment; please address it when you're free.
trlx/utils/modeling.py (outdated)

```diff
@@ -280,6 +281,24 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
         return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()
 
 
+def gather_for_metrics(tensor, expected_number, batch_size, length):
```
This seems to work in our case because the tensor has expected_number == dataset_size, but if you loop through the eval dataloader and collect metrics with this, the last batch might contain duplicate entries from exhausted ranks when the world size, dataset size, and batch size aren't aligned. Maybe we should leave a warning comment so that in the future we know to manually truncate them in such loops (see the sketch below)? One might expect it to have the same behavior as accelerate's function here.
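A minimal sketch of the manual truncation this refers to (not from the PR; it assumes an accelerate-prepared eval dataloader and a hypothetical `model_step` helper):

```python
import torch

all_outputs = []
for batch in eval_dataloader:
    outputs = model_step(batch)                      # hypothetical per-batch computation
    all_outputs.append(accelerator.gather(outputs))  # plain gather keeps the padded tail
gathered = torch.cat(all_outputs)
# The last batch may contain duplicate entries from exhausted ranks,
# so truncate to the true dataset length before computing metrics.
gathered = gathered[: len(eval_dataloader.dataset)]
```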
What you describe is exactly what this PR amends (or at least intends to 🥲).
Let's use very unaligned hyperparameters: batch_size=12, world_size=3, len(eval_prompts)=109. Running on main with

```
accelerate launch --num_processes 3 --config_file configs/accelerate/zero2-bf16.yaml examples/ppo_sentiments.py
```

gives only 1 eval sample – https://wandb.ai/sorry/trlx/runs/vh3ayv6b
(to reproduce, you can pull main...1-eval-sample or go to main and set those values manually, as in the sketch below)
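A rough sketch of those manual overrides (the field names are assumptions and may not match the exact config schema in examples/ppo_sentiments.py):

```python
# Assumed names, for illustration only
config.train.batch_size = 12
eval_prompts = [str(i) for i in range(109)]  # 109 prompts, deliberately not divisible by 12 * 3
```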
But with the new gather_for_metrics it's back to the expected 109 – https://wandb.ai/sorry/trlx/runs/c6agq89q/
Let's also check that the prompts weren't perturbed (because it's ambiguous in this example) by setting eval_prompts = list(map(str, range(109))): https://wandb.ai/sorry/trlx/runs/c6agq89q
(On a similar note: can one add tests that rely on multiprocessing but would still be run on CI?)
Oh yeah, this definitely fixes our specific eval code, but if you were to use this to collect metrics in a dataloader loop (for example, when one big tensor doesn't fit in RAM) it might still duplicate entries. You can repro the behavior with this gist: https://gist.github.com/jon-tow/304efadc9fce470d7b4f7212d5cfcf18 (sorry, I couldn't figure out how to simulate multi-process on Google Colab lol). I could be misunderstanding things, so let me know :)
Re CI with tests on multi-process functions: I'm not sure 😅 We could always write the test and have it around to at least run locally, then skip the test case in CI. I'll get back to you on that!
That's a great code snippet Jon, thanks! I see now that you are correct: accelerate.gather_for_metrics and this gather_for_metrics are two slightly different functions with different usages. To avoid introducing any additional code, it's possible to refactor the existing code to resolve the issue instead: accelerate.gather_for_metrics now happens for each batch. Additionally, elements of the batch are now padded to the batch elements' max_length instead of the global seq_length, and they are collected as lists, to avoid one big allocation (possibly bigger than RAM), since they don't have to be tensors afterwards. A rough sketch of the new flow is below.
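Something along these lines (assumed names, not the exact diff), using accelerate's pad_across_processes and gather_for_metrics per batch:

```python
all_str_samples = []
for prompts in eval_dataloader:
    samples = generate(prompts)  # hypothetical generation step
    # Pad only to this batch's longest sample across ranks, not to a global seq_length
    samples = accelerator.pad_across_processes(
        samples, dim=1, pad_index=tokenizer.pad_token_id
    )
    # gather_for_metrics drops the duplicated tail entries of the last batch
    samples = accelerator.gather_for_metrics(samples)
    # Keep results as plain Python lists so there is never one huge tensor allocation
    all_str_samples.extend(tokenizer.batch_decode(samples, skip_special_tokens=True))
```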
109 example: https://wandb.ai/sorry/trlx/runs/t8nc849j
Comparison with main: https://wandb.ai/sorry/trlx/reports/Make-gather_for_metrics-more-strict-315--VmlldzozNTkyMTUy
Re CI with tests on multi-process functions: I guess there is no rush for it as of now 😅
Oh, that's a very clean refactor! For some reason, this breaks seq2seq evaluation with a RuntimeError from non-contiguous tensors in the underlying gather call 🤔 Can you reproduce it on your end? Full traceback and config run here: https://gist.github.com/jon-tow/8154845bd05cea3946e35b4a7f89a88c
I think once this is cleared up we'll be good 🤞
Thanks for spotting it in my stead! The fix is easy enough (see the sketch below): https://wandb.ai/sorry/trlx/reports/Make-gather_for_metrics-usage-more-strict-315--VmlldzozNTk2ODEz
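A sketch of the kind of one-line fix this implies (distributed gather ops require contiguous tensors, so slices produced during seq2seq decoding need to be made contiguous first); not necessarily the exact diff:

```python
# Make the tensor contiguous before the distributed gather to avoid the RuntimeError
samples = accelerator.gather_for_metrics(samples.contiguous())
```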
Changed the title: "Make gather_for_metrics more strict" → "Make gather_for_metrics usage more strict"
Awesome! The gather-hanging issue I brought up at the meeting seems related to the num_rollouts config as you mentioned, so ignore it :)
This PR makes accelerate.gather_for_metrics happen for each batch. Additionally, elements of the batch are now padded to the batch elements' max_length instead of the global seq_length, and they are collected as lists, to avoid one big allocation (possibly bigger than RAM), since they don't have to be tensors afterwards.
https://wandb.ai/sorry/trlx/reports/Make-gather_for_metrics-more-strict-315--VmlldzozNTkyMTUy