
Implement BoN for training and eval #528

Open · wants to merge 40 commits into main
Conversation

@Dahoas (Collaborator) commented Jul 18, 2023

No description provided.

Review comment on trlx/trainer/accelerate_ppo_trainer.py (resolved):

I think this example got here by inertia from the previous PR

Review comment on trlx/trainer/accelerate_base_trainer.py (resolved):

    torch.distributed.scatter(scores, all_scores)
else:
    scores = all_scores[0].clone().detach()
# Best-of-N Sampling.
scores_mask = scores != -1

I think we need to merge changes from your last PR in
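For context on what the masking above feeds into: a minimal sketch of best-of-N selection over padded scores, assuming (as in the snippet) a flat scores tensor holding N return sequences per prompt with -1 marking padded entries. select_best_of_n is a hypothetical helper written for illustration, not this PR's code.

import torch

def select_best_of_n(samples, scores, num_return_sequences):
    # Mask out padded scores (-1) so they can never be selected
    scores = scores.clone().float()
    scores_mask = scores != -1
    scores[~scores_mask] = float("-inf")
    # Group into [num_prompts, N] and pick the best sample per prompt
    grouped = scores.view(-1, num_return_sequences)
    best = grouped.argmax(dim=-1)
    flat = best + torch.arange(best.numel(), device=scores.device) * num_return_sequences
    return [samples[i] for i in flat.tolist()]

For example, with num_return_sequences=4 and 2 prompts, samples holds 8 generations and the helper returns the highest-scoring generation for each prompt.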

@maxreciprocate (Collaborator) left a comment

Looks slick!

(One thing I've noticed is that while the score increases much faster on the training set as num_return_sequences increases, this doesn't necessarily yield a better score on the test set. Do you perhaps have an example or parameter setting where it does?)

[Screenshots, 2023-08-11 15:30 and 15:35: training vs. test score curves]

Resolved review threads: trlx/models/modeling_ppo.py, trlx/trainer/accelerate_base_trainer.py, trlx/trainer/accelerate_ppo_trainer.py
@@ -520,3 +585,17 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq

        # Push samples and rewards to trainer's rollout storage
        self.push_to_store(ppo_rl_elements)

    @staticmethod
    def get_topk_indices(input_tensor, window_size: int, k: int, device):
Nit: maybe a docstring should be added specifying that this isn't the same as regular topk but rather a topk over window_size
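To illustrate the distinction the nit is pointing at, a windowed top-k can be sketched roughly as below. This is an illustration under stated assumptions (windowed_topk_indices is a made-up name, and the input length is assumed divisible by window_size), not the PR's get_topk_indices implementation.

import torch

def windowed_topk_indices(scores: torch.Tensor, window_size: int, k: int) -> torch.Tensor:
    """Top-k indices computed independently within each window of window_size,
    returned as indices into the original flat tensor (unlike a plain torch.topk
    over the whole tensor)."""
    windows = scores.view(-1, window_size)                    # [num_windows, window_size]
    local_idx = torch.topk(windows, k=k, dim=-1).indices      # indices within each window
    offsets = torch.arange(windows.shape[0], device=scores.device).unsqueeze(-1) * window_size
    return (local_idx + offsets).flatten()                    # indices into the flat tensor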

@Dahoas (Collaborator, Author) commented Aug 21, 2023

Good point, the benefit of BoN training seems to be problem dependent. I've seen the most benefit during training on problems where the model has a low pass@1 score.
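For readers unfamiliar with the metric: pass@1 is the probability that a single sampled generation solves the problem. A small sketch of the standard unbiased pass@k estimator (Chen et al., 2021), written here only for reference and not part of this PR:

from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    # Probability that at least one of k samples, drawn without replacement
    # from n generated samples of which c are correct, is correct.
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(100, 7, 1))  # 0.07: 7 correct out of 100 samples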

@Dahoas (Collaborator, Author) commented Aug 28, 2023

@maxreciprocate If you're happy with this, do you want to merge today?

@maxreciprocate (Collaborator) commented Sep 1, 2023

@Dahoas There are some run differences when using the default config without BoN sampling, most notably for the randomwalks case:
https://wandb.ai/sorry/trlx-references/reports/BoN-v-main--Vmlldzo1MjkwMzA5
Probably some minor implementation detail, have to recheck.

@Dahoas (Collaborator, Author) commented Sep 4, 2023

Let me look into why.

@maxreciprocate (Collaborator)

@Dahoas Not sure if that's the issue however, see: https://wandb.ai/sorry/trlx/reports/Difference-due-to-the-change-in-base_trainer-decode--Vmlldzo1MzE2OTg4 (+ some non-determinism)
