Implement BoN for training and eval #528
base: main
Conversation
examples/ppo_redemption.py
I think this example got here by inertia from the previous PR
    torch.distributed.scatter(scores, all_scores)
else:
    scores = all_scores[0].clone().detach()
# Best-of-N Sampling.
scores_mask = scores != -1
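As an aside, here is a minimal sketch of how such a `-1` sentinel can be turned into a validity mask before computing reward statistics. The shapes and the padding convention are assumptions for illustration, not the PR's actual code:

```python
import torch

# Hypothetical example: scores for 2 prompts x 4 samples each, where -1 marks
# padding slots that did not receive a reward (an assumed convention).
scores = torch.tensor([[0.3, -1.0, 0.9, 0.1],
                       [0.5, 0.7, -1.0, -1.0]])

# Mask out padded entries so they don't skew statistics.
scores_mask = scores != -1

# Mean over valid entries only.
valid_mean = (scores * scores_mask).sum() / scores_mask.sum()
print(valid_mean)
```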
I think we need to merge changes from your last PR in
@@ -520,3 +585,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa

        # Push samples and rewards to trainer's rollout storage
        self.push_to_store(ppo_rl_elements)

    @staticmethod
    def get_topk_indices(input_tensor, window_size: int, k: int, device):
Nit: maybe a docstring should be added specifying that this isn't the same as a regular topk, but rather a topk over window_size
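For illustration, a minimal sketch of what such a windowed top-k could look like (assumed semantics, not necessarily what `get_topk_indices` in this PR does): scores are grouped into consecutive blocks of `window_size`, top-k is taken within each block, and the indices are shifted back into the flat tensor.

```python
import torch

def topk_indices_per_window(input_tensor: torch.Tensor, window_size: int, k: int) -> torch.Tensor:
    """Return indices of the top-k entries inside each consecutive window.

    Unlike a regular topk over the whole tensor, this selects k entries
    from every block of `window_size` elements (e.g. the N samples
    generated for one prompt in Best-of-N sampling).
    """
    # Group the flat scores into windows: (num_windows, window_size).
    windows = input_tensor.reshape(-1, window_size)
    # Top-k within each window; indices are relative to the window.
    _, local_idx = torch.topk(windows, k=k, dim=-1)
    # Shift back to indices into the original flat tensor.
    offsets = torch.arange(windows.shape[0], device=input_tensor.device).unsqueeze(-1) * window_size
    return (local_idx + offsets).flatten()

# Example: 2 prompts x 4 samples each, keep the best 2 samples per prompt.
scores = torch.tensor([0.1, 0.9, 0.3, 0.5, 0.7, 0.2, 0.8, 0.4])
print(topk_indices_per_window(scores, window_size=4, k=2))  # tensor([1, 3, 6, 4])
```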
Good point, the benefit of BoN training seems to be problem-dependent. I've seen the most benefit when training on problems where the model has a low pass@1 score.
@maxreciprocate If you're happy with this, do you want to merge today?
@Dahoas There are some run differences when using the default config without BoN sampling, most notably for the randomwalks case:
Let me look into why.
@Dahoas Not sure if that's the issue, however; see: https://wandb.ai/sorry/trlx/reports/Difference-due-to-the-change-in-base_trainer-decode--Vmlldzo1MzE2OTg4 (+ some non-determinism)