Buffer checkpointing and offline data filtering #839
Many thanks to @semioz for pioneering serializing the buffer as a HF dataset in #823, which made my life much simpler. Also, sorry for taking over so abruptly, but we needed this ASAP for an internal run and there were some new requirements, so this was the quicker way :) You get credit for your contribution anyways.
samsja left a comment:
lgtm
There are two problems on current main that this PR addresses:
- The buffer state is not checkpointed; for `DifficultyPoolBuffer` pools we want to keep the information of which sample belongs to which difficulty pool across restarts.
- There is no way to initialize the buffer from an offline-filtered dataset, i.e. to do offline data filtering.

Changes in this PR:
- Adds (de)serialization to the `Buffer` class, i.e. a method `load` which loads a HF dataset from a specified path and then initializes the buffer in-place, and a method `save` which saves the buffer state as a HF dataset. To do so, it serializes the metadata and rollout buffer into `json` strings and writes them as dataset columns (a rough sketch of this pattern is included after this list).
- Passes the `Buffer` object into the orchestrator's `CheckpointManager` to `save` and `load` the buffer in-place, following the general pattern of checkpointing across training.
- Adds a config field `--config.buffer.from-scratch` which specifies whether to attempt to load the `metadata` and `rollout` columns from the dataset provided to the environment. Defaults to `True`, which initializes `metadata` and `rollout` from scratch.
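To make the first bullet concrete, here is a minimal sketch of that (de)serialization pattern, assuming a simplified buffer with `metadata` and `rollouts` fields. The class name, schema, and field types are illustrative only, not the actual `Buffer` implementation in this PR.

```python
import json
from pathlib import Path

from datasets import Dataset, load_from_disk


class BufferSketch:
    """Simplified stand-in for the Buffer class; real fields and types differ."""

    def __init__(self, problems: list[dict]):
        self.problems = problems                # the underlying problem rows
        self.metadata = [{} for _ in problems]  # e.g. last reward or difficulty pool
        self.rollouts = [[] for _ in problems]  # buffered rollouts per problem

    def save(self, path: Path) -> None:
        # Serialize the metadata and rollout buffers into JSON strings and
        # write them as extra columns next to the original problem columns.
        ds = Dataset.from_list(self.problems)
        ds = ds.add_column("metadata", [json.dumps(m) for m in self.metadata])
        ds = ds.add_column("rollouts", [json.dumps(r) for r in self.rollouts])
        ds.save_to_disk(str(path))

    def load(self, path: Path) -> None:
        # Restore the buffer state in-place from a previously saved dataset.
        ds = load_from_disk(str(path))
        self.metadata = [json.loads(m) for m in ds["metadata"]]
        self.rollouts = [json.loads(r) for r in ds["rollouts"]]
        self.problems = ds.remove_columns(["metadata", "rollouts"]).to_list()
```

With such an interface, the orchestrator's `CheckpointManager` only needs to call `save` when writing a checkpoint and `load` when resuming, which is the integration pattern described in the second bullet.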
These changes solve both problems:

- Use `--ckpt` to checkpoint the progress and buffer on the orchestrator, and `--ckpt.resume-step` to resume from a given checkpoint step (Examples 1, 2, 3).
- Unset the buffer's `from-scratch` field to initialize `metadata` and `rollout` from an offline-filtered dataset (see Manual filtering below).

TODO:
- `simple` (no metadata or rollout buffer is written)
- `online-difficulty` (metadata of rewards is written)
- `difficulty-pool` (rollout buffer is written)

Simple Buffer
Confirmed that the buffer checkpoints written at steps 5 and 10 do not include any metadata or rollouts.
Online Difficulty
Confirmed that the buffer checkpoints written at steps 5 and 10 do not include any rollouts but collect metadata on the most recent reward achieved.
Difficulty Pool
Confirmed that the buffer checkpoints written at steps 5 and 10 do not include any rollouts (we purge them) but collect metadata on the difficulty, and that the distribution over pools changes.
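These confirmations can be reproduced by loading the saved buffer dataset and inspecting its serialized columns. The checkpoint path and the `pool` metadata key below are assumptions for illustration; the actual layout is determined by the orchestrator's checkpointing code.

```python
import json
from collections import Counter

from datasets import load_from_disk

# Hypothetical checkpoint location; the real path is decided by the
# orchestrator's CheckpointManager and the --ckpt settings.
ds = load_from_disk("checkpoints/step_10/buffer")

rollouts = [json.loads(r) for r in ds["rollouts"]]
metadata = [json.loads(m) for m in ds["metadata"]]

# For the difficulty-pool buffer we expect no buffered rollouts (they are purged),
# but per-sample difficulty metadata whose pool distribution shifts over training.
print("samples with buffered rollouts:", sum(1 for r in rollouts if r))
print("pool distribution:", Counter(m.get("pool") for m in metadata))
```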
Manual filtering
1. Create a new dataset (pushed to hub as `mikasenghaas/Reverse-Text-RL-PR-839` (HF) as an example); a sketch of how such a dataset could be built is included at the end of this section.
2. Make sure that your environment exposes `dataset_name` (optionally also `dataset_split`), as done in this PR for `reverse-text`, and that you unset the `from_scratch` field in the buffer config to correctly load the `metadata` and `rollouts` fields.

```bash
uv run rl \
  --trainer @ configs/reverse_text/train.toml \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type difficulty-pool \
  --no-orchestrator.buffer.from-scratch \
  --orchestrator.environment.args '{"dataset_name": "mikasenghaas/Reverse-Text-RL-PR-839"}' \
  --max-steps 10
```

Again, we check that the buffer was correctly initialized from this new dataset because the distribution over pools shows that ~half of the examples are in the `easy` pool, as initialized.
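For completeness, a dataset like `mikasenghaas/Reverse-Text-RL-PR-839` could be built roughly as follows. The prompt/answer fields, the `pool` metadata key, and the second pool name are assumptions made for this sketch; only the `metadata`/`rollouts` column names and the hub repo id follow the description above.

```python
import json

from datasets import Dataset

# Tiny synthetic stand-in for the reverse-text data; the real prompt/answer
# fields of the environment dataset may be named differently.
rows = [{"prompt": f"reverse this string {i}", "answer": f"{i} gnirts siht esrever"} for i in range(8)]

# Offline annotation: put roughly half of the samples into the "easy" pool.
metadata = [json.dumps({"pool": "easy" if i % 2 == 0 else "normal"}) for i in range(len(rows))]
rollouts = [json.dumps([]) for _ in rows]  # no pre-filled rollouts

ds = Dataset.from_list(rows)
ds = ds.add_column("metadata", metadata)
ds = ds.add_column("rollouts", rollouts)

# Push so it can be referenced via
# --orchestrator.environment.args '{"dataset_name": "mikasenghaas/Reverse-Text-RL-PR-839"}'
ds.push_to_hub("mikasenghaas/Reverse-Text-RL-PR-839")
```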
GitHub Issue: #748
Linear Issue: Resolves PRIMERL-22