@mikasenghaas (Member) commented on Aug 26, 2025

There are two problems on current main that this PR addresses:

  1. We always initialize the data buffer from scratch, which becomes a problem when resuming a training run with a stateful data buffer whose problem or rollout sampling depends on metadata accumulated during the run. For example, for DifficultyPoolBuffer we want to keep the information of which sample belongs to which difficulty pool across restarts.
  2. We may want to do some offline filtering in between restarts (e.g. multi-stage RL). One example of this is filtering out samples with an average reward of 1.0 before a subsequent run.

Changes in this PR:

  • Add serialization logic to the base Buffer class, i.e. a load method which loads a HF dataset from a specified path and then initializes the buffer in-place, and a save method which saves the buffer state as a HF dataset. To do so, it serializes the metadata and rollout buffers into JSON strings and writes them as dataset columns (see the sketch after this list).
  • Pass a Buffer object into the orchestrator's CheckpointManager to save and load the buffer in-place, following the general checkpointing pattern used across training
  • Add a flag --config.buffer.from-scratch which specifies whether to attempt to load the metadata and rollout columns from the dataset provided to the environment. Defaults to True, which initializes metadata and rollouts from scratch.
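For illustration, here is a minimal sketch of the save/load scheme. The method names and the JSON-in-columns approach come from this PR; the simplified Buffer internals (a list of per-sample metadata dicts and rollout lists) are assumptions for the sake of the example.

import json
from datasets import Dataset, load_from_disk

class Buffer:
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        # Per-sample state accumulated during the run (simplified)
        self.metadata = [{} for _ in range(len(dataset))]
        self.rollouts = [[] for _ in range(len(dataset))]

    def save(self, path: str) -> None:
        # Serialize metadata and rollout buffers to JSON strings and
        # write them as dataset columns alongside the original data
        ds = self.dataset.add_column("metadata", [json.dumps(m) for m in self.metadata])
        ds = ds.add_column("rollouts", [json.dumps(r) for r in self.rollouts])
        ds.save_to_disk(path)

    def load(self, path: str) -> None:
        # Initialize the buffer in-place from a previously saved dataset
        ds = load_from_disk(path)
        self.metadata = [json.loads(m) for m in ds["metadata"]]
        self.rollouts = [json.loads(r) for r in ds["rollouts"]]
        self.dataset = ds.remove_columns(["metadata", "rollouts"])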

These changes solve both problems:

  1. We can resume the orchestrator while retaining the buffer state with no changes to the orchestrator or config API, i.e. only specify --ckpt to checkpoint the progress and buffer on the orchestrator, and --ckpt.resume-step to resume from a given checkpoint step (Examples 1-3)
  2. We can take a buffer checkpoint, perform offline filtering on it (e.g. by filtering on the avg. reward in the metadata column), and then start a new run by passing the updated dataset as an environment argument (see the sketch below). This pattern should be generic enough to allow for any data transformation/processing steps we may wish to do in between restarts, given that we write the correct metadata during the run.
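As a concrete example of such a transformation, here is a hedged sketch of the filtering step. The metadata and rollouts column names follow this PR; the avg_reward key, the checkpoint path, and the target repo id are assumptions, standing in for whatever metadata is actually written during the run.

import json
from datasets import load_from_disk

# Load the buffer checkpoint written by the orchestrator (hypothetical path)
buffer = load_from_disk("ckpts/step_5/buffer")

# Drop samples that were always solved, i.e. have an average reward of 1.0
filtered = buffer.filter(lambda row: json.loads(row["metadata"]).get("avg_reward", 0.0) < 1.0)

# Publish the filtered dataset so a new run can consume it via the environment args
filtered.push_to_hub("my-org/Reverse-Text-RL-filtered")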

TODO:

  • Validate resuming logic with simple buffer (no metadata or rollout buffer is written)
  • Validate resuming logic with online-difficulty buffer (reward metadata is written)
  • Validate resuming logic with difficulty-pool buffer (rollout buffer is written)
  • Validate initializing with metadata/rollout columns

Simple Buffer

uv run rl  \
  --trainer @ configs/reverse_text/train.toml   \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type simple \
  --max-steps 5 \
  --ckpt

uv run rl  \
  --trainer @ configs/reverse_text/train.toml   \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type simple \
  --max-steps 10 \
  --ckpt.resume-step 5

Confirmed that the buffer checkpoints written at steps 5 and 10 do not include any metadata or rollouts.
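For reference, a checkpointed buffer can be inspected directly as a HF dataset; the checkpoint path below is a hypothetical example, while the column scheme follows this PR.

import json
from datasets import load_from_disk

buffer = load_from_disk("ckpts/step_5/buffer")  # hypothetical checkpoint path
print(buffer.column_names)
print(json.loads(buffer[0]["metadata"]))  # empty for the simple buffer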


Online Difficulty

uv run rl   \
  --trainer @ configs/reverse_text/train.toml   \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type online-difficulty \
  --max-steps 5 \
  --ckpt

uv run rl  \
  --trainer @ configs/reverse_text/train.toml   \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type online-difficulty \
  --max-steps 10 \
  --ckpt.resume-step 5

Confirmed that the buffer checkpoints written at steps 5 and 10 do not include any rollouts but collect metadata on the most recent reward achieved.


Difficulty Pool

uv run rl   \
  --trainer @ configs/reverse_text/train.toml   \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type difficulty-pool \
  --max-steps 5 \
  --ckpt

uv run rl   \
  --trainer @ configs/reverse_text/train.toml   \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type difficulty-pool  \
  --max-steps 10 \
  --ckpt.resume-step 5

Confirmed that the buffer checkpoints written at steps 5 and 10 do not include any rollouts (we purge them) but collect metadata on the difficulty, and that the distribution over pools changes.


Manual filtering

Create a new dataset (pushed to the Hub as mikasenghaas/Reverse-Text-RL-PR-839 as an example):

import json
import random
from datasets import Dataset, load_dataset

dataset = load_dataset("PrimeIntellect/Reverse-Text-RL", split="train")
assert isinstance(dataset, Dataset)

# Assign each sample a random difficulty, serialized as a JSON metadata column
difficulties = [json.dumps({"difficulty": random.choice(["easy", "normal"])}) for _ in range(len(dataset))]
dataset = dataset.add_column("metadata", difficulties, new_fingerprint="metadata")

# Start every sample with an empty rollout buffer
dataset = dataset.add_column("rollouts", [json.dumps([]) for _ in range(len(dataset))], new_fingerprint="rollout")

dataset.push_to_hub("mikasenghaas/Reverse-Text-RL-PR-839")

Make sure that your environment exposes dataset_name (and optionally dataset_split) as done in this PR for reverse-text, and that you unset the from_scratch field in the buffer config so that the metadata and rollouts columns are loaded correctly.
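For reference, a hedged sketch of what exposing these arguments might look like. The argument names come from this PR; the function shape, defaults, and return value are assumptions, not the actual reverse-text implementation.

from datasets import load_dataset

def load_environment(
    dataset_name: str = "PrimeIntellect/Reverse-Text-RL",
    dataset_split: str = "train",
):
    # Accepting the dataset identifier as an argument lets a run point at a
    # filtered/augmented copy, e.g. mikasenghaas/Reverse-Text-RL-PR-839
    dataset = load_dataset(dataset_name, split=dataset_split)
    return dataset  # the real environment wraps this dataset; sketch only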

uv run rl  \
  --trainer @ configs/reverse_text/train.toml  \
  --orchestrator @ configs/reverse_text/orch.toml \
  --log.level debug \
  --orchestrator.buffer.type difficulty-pool \
  --no-orchestrator.buffer.from-scratch \
  --orchestrator.environment.args '{"dataset_name": "mikasenghaas/Reverse-Text-RL-PR-839"}' \
  --max-steps 10

Again, we check that the buffer was correctly initialized from this new dataset: the distribution over pools shows that ~half of the examples are in the easy pool, as initialized.


GitHub Issue: #748
Linear Issue: Resolves PRIMERL-22

@mikasenghaas force-pushed the mika/feat/buffer-ckpt+offline-filter branch from bbb4f00 to 2957673 on August 26, 2025 09:57
@mikasenghaas marked this pull request as ready for review on August 26, 2025 10:14
@mikasenghaas (Member, Author) commented:

Many thanks to @semioz for pioneering serializing the buffer as a HF dataset in #823; it made my life much simpler. Also sorry for taking over so abruptly, but we needed this ASAP for an internal run and there were some new requirements, so this was the quicker way :) You get your contribution anyways

@samsja (Member) commented:

lgtm

@mikasenghaas force-pushed the mika/feat/buffer-ckpt+offline-filter branch from 5b8b713 to b5620d9 on August 26, 2025 17:17
@mikasenghaas merged commit 6829b2d into main on Aug 26, 2025
5 checks passed