Implement serialization to make buffer checkpoint-compliant #823
Conversation
Sorry, we'll fix the W&B API key so the integration test runs properly.

@mikasenghaas I was about to test, but should I also convert my torch.save's to safetensors as you mentioned in this issue? #821

@semioz nope, hold off on this for now. we will rewrite from
mikasenghaas left a comment
Couple of comments, but this is on a good track! Nice job.
src/prime_rl/orchestrator/buffer.py (outdated)
torch.save(self.rollout_buffer, path / "rollout_buffer.pt")
torch.save(self.metadata, path / "metadata.pt")
self.dataset.save_to_disk(path / "buffer_dataset")
Hm, I can see how this is the easiest way to implement this, but I wonder if we could somehow save/load from a single dataset instance. E.g. the metadata could easily be a column in the dataset; the rollout buffer is a bit trickier though. I think it's fine like this for now, but a single HF dataset feels like the cleanest way of serializing.
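To make the metadata-as-column idea concrete, here is a minimal sketch (not the repo's actual implementation). It assumes `self.metadata` is a flat dict of scalars and `self.dataset` is an HF `Dataset` with at least one row; the method names are hypothetical:

```python
from pathlib import Path

from datasets import load_from_disk


class Buffer:
    # ... existing fields: dataset (HF Dataset), metadata (flat dict), rollout_buffer ...

    def save_with_metadata_columns(self, path: Path) -> None:
        # Repeat each scalar metadata value into a column of the existing dataset,
        # so a single save_to_disk call captures both the dataset and the metadata.
        dataset = self.dataset
        for key, value in self.metadata.items():
            dataset = dataset.add_column(key, [value] * len(dataset))
        dataset.save_to_disk(str(path / "buffer_dataset"))

    def load_with_metadata_columns(self, path: Path) -> None:
        # Assumes at least one row and that the metadata keys are known upfront.
        dataset = load_from_disk(str(path / "buffer_dataset"))
        self.metadata = {key: dataset[0][key] for key in self.metadata}
        self.dataset = dataset.remove_columns(list(self.metadata))
```

The rollout buffer would still need its own representation (extra rows or serialized columns), which is the trickier part mentioned above.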
src/prime_rl/orchestrator/buffer.py (outdated)
return sampled_rollouts

def load_buffer(path: Path, buffer_config: DataBufferConfigType) -> Buffer:
can we make this function in-place like our other checkpoint logic?
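As a hedged sketch of what "in-place" could mean here (attribute names taken from the snippets above, everything else hypothetical), the free function could become a method that restores state onto the existing instance:

```python
from pathlib import Path

import torch
from datasets import load_from_disk


class Buffer:
    # ... existing __init__, sample, save, ... ...

    def load(self, path: Path) -> None:
        # Mutate the existing instance instead of constructing and returning a new
        # Buffer, so call sites keep their reference to the same object.
        # weights_only=False because the buffer holds arbitrary Python objects.
        self.rollout_buffer = torch.load(path / "rollout_buffer.pt", weights_only=False)
        self.metadata = torch.load(path / "metadata.pt", weights_only=False)
        self.dataset = load_from_disk(str(path / "buffer_dataset"))
```

Call sites would then do `buffer.load(ckpt_path / "buffer")` rather than reassigning the variable.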
src/prime_rl/orchestrator/ckpt.py (outdated)
buffer_path = step_path / "buffer"
buffer.save(buffer_path)
this should be inside _save_to_path for async checkpointing to work properly
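Roughly what that move could look like, as a sketch only (assuming `_save_to_path` is the method the checkpoint manager's async path executes; the surrounding class is hypothetical):

```python
from pathlib import Path


class CheckpointManager:
    # ... existing fields and async save machinery ...

    def _save_to_path(self, step_path: Path, buffer: "Buffer") -> None:
        step_path.mkdir(parents=True, exist_ok=True)
        # ... existing checkpoint state (progress, rng, etc.) saved here ...
        # Saving the buffer inside this method keeps it on the same code path the
        # async checkpointing runs, instead of being written synchronously by the caller.
        buffer.save(step_path / "buffer")
```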
# Load buffer from checkpoint
if config.ckpt.resume_buffer_from_checkpoint:
    logger.info("Resuming buffer from checkpoint")
    buffer_path = ckpt_manager.get_ckpt_path(config.ckpt.resume_step) / "buffer"
    buffer = load_buffer(str(buffer_path), config.buffer)
else:
    logger.info("Initializing buffer from scratch")
    buffer = setup_buffer(dataset, config.buffer)
if we make the suggested changes to the buffer load function and the checkpoint, there should be very minimal changes to the orchestrator code
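For illustration, assuming the in-place `buffer.load` and the checkpoint-manager change sketched above, the orchestrator snippet could shrink to something like this (a sketch, not the final code):

```python
# Always set up the buffer, then optionally restore its state in place.
buffer = setup_buffer(dataset, config.buffer)
if config.ckpt.resume_buffer_from_checkpoint:
    logger.info("Resuming buffer from checkpoint")
    buffer.load(ckpt_manager.get_ckpt_path(config.ckpt.resume_step) / "buffer")
```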
@mikasenghaas thanks. Converted the load/save process to fully HF via columns and moved the resume option to the buffer config (let me know if you don't want to make it configurable at all), plus the other changes you requested. Ready for your review again.

Closing as continued in #839
Resumes the data buffer properly from a checkpoint. When saving, the buffer stores its rollout buffer, metadata, and dataset to disk. Whether to restore from a checkpoint or start fresh is controlled by a simple config flag, just like resume_step.
GitHub Issue: #748
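As a hedged illustration of the flag described above: the field `resume_buffer_from_checkpoint` appears in the snippet earlier in the thread, but the config class name and defaults here are assumptions, and per the follow-up comment the option was later moved to the buffer config.

```python
from pydantic import BaseModel


class CheckpointConfig(BaseModel):
    # Step to resume from; None means starting from scratch (existing behavior).
    resume_step: int | None = None
    # New flag: also restore the data buffer state from that checkpoint's "buffer" dir.
    resume_buffer_from_checkpoint: bool = False
```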