Improve documentation/comments on the random walk example #208
Conversation
Force-pushed from 4c79c09 to acade3e
Thanks for putting this together @alan-cooney. I had some trouble running things locally. Can you address the requested changes and provide a wandb report for both examples to ensure everything is working as expected? Thanks!
Force-pushed from 41fe413 to f8c3bda
Thanks for the quick review @jon-tow. PPO results - https://wandb.ai/alancooney/trlx/runs/eo1vxg53

Typings fix

By the way, I had to fix the typings in trlx/trlx.py so that they work with the approach you prefer here. This fix is needed in any case, as the typings were incorrect (the metric function only takes samples - it's just the reward function that also takes prompts and outputs):

trlx/trlx/trainer/accelerate_base_trainer.py, line 357 in 84a0711:

```python
metrics = self.metric_fn(str_samples)
```

However, I'm happy to move this fix to a different PR if you want to keep the commit history clean (it's a small change, but it doesn't really belong here).
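For context, a minimal sketch of what the corrected annotations in trlx/trlx.py could look like; the alias names and exact generic parameters here are assumptions for illustration, not the repository's actual code:

```python
from typing import Callable, Dict, List

# Hypothetical aliases: after the fix, metric_fn takes the same
# (samples, prompts, outputs) arguments as reward_fn, rather than samples only.
RewardFn = Callable[[List[str], List[str], List[str]], List[float]]
MetricFn = Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]
```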
Oh huh; this slipped under the radar. Thanks for the find 🙏 metric_fn should match the reward_fn signature and be called as

```python
metrics = self.metric_fn(
    samples=str_samples,
    prompts=str_prompts,
    outputs=str_outputs,
)
```

on the line that you've highlighted. It's a small enough change that we can squeeze it into this PR. Let me know if that's alright with you! Once that's done we should be good to merge 👍
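For illustration, a hedged sketch of a reward_fn/metric_fn pair whose signatures match the keyword call above; the function bodies are hypothetical stand-ins, not the random walk example's actual logic:

```python
from typing import Dict, List

def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]:
    # Hypothetical reward: score each generated output (stand-in logic).
    return [float(len(output)) for output in outputs]

def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> Dict[str, List[float]]:
    # With the fix above, metric_fn receives the same keyword arguments as
    # reward_fn, so both are called with samples=..., prompts=..., outputs=...
    return {"output_length": [float(len(output)) for output in outputs]}
```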
Makes sense! All done, and I've checked the runs work as well: PPO - https://wandb.ai/alancooney/trlx/runs/agtv55vb
Have to say this is great work! However, it's rather peculiar that I cannot reproduce your PPO wandb run from this branch, despite my run being identical; see https://wandb.ai/sorry/trlx/reports/random_walks_document-v-main--VmlldzozMzkyNjE1
Force-pushed from 3378119 to c8a776a
Sure, this seems plausible. I get the same results on main as on this branch, so I think it's all fine. Main - https://wandb.ai/alancooney/trlx/runs/ixhfy261

As I understand it, CUDA + different hardware can cause different results with the same random seed. But in terms of environments, I'm using this docker image #196 and then

Note: sorry about the force push - committed to the wrong branch by mistake, so I've put this branch back to where you both reviewed it.
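As an aside, a minimal sketch of the kind of RNG seeding that still leaves room for such cross-hardware differences; this is generic PyTorch seeding for illustration, not code from this PR:

```python
import random

import numpy as np
import torch

def set_seed(seed: int = 1000) -> None:
    # Seed the Python, NumPy, and (all-device) PyTorch RNGs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Even with identical seeds, some CUDA kernels are non-deterministic by
    # default, and different GPUs can produce different floating-point results.
    # Forcing deterministic algorithms narrows, but does not eliminate,
    # cross-hardware differences.
    torch.use_deterministic_algorithms(True, warn_only=True)
```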
I can also reproduce the runs from main:

Thanks a bunch, @alan-cooney!
Makes this example more readable: