Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions src/prime_rl/orchestrator/buffer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import random
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from pathlib import Path

from datasets import Dataset

Expand Down Expand Up @@ -96,6 +98,72 @@ def __init__(self, dataset: Dataset):
self.rollout_buffer: dict[int, list[Rollout]] = {}
self.metadata: dict[int, dict] = {problem_id: {} for problem_id in self.problem_ids}

def save(self, path: Path):
    """Persist the buffer state to disk as a single HF dataset.

    Per-problem metadata and rollouts are JSON-serialized into two extra
    columns ("metadata", "rollouts"), aligned to the dataset rows by index,
    and the augmented dataset is written with `Dataset.save_to_disk`.

    Args:
        path: Target directory for the serialized dataset.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    # Strip stale columns if present before proceeding, so a repeated save
    # does not fail on duplicate column names.
    dataset = self.dataset.remove_columns(
        [c for c in ("metadata", "rollouts") if c in self.dataset.column_names]
    )

    # Serialize metadata and rollouts row-by-row; rows with no entry get an
    # empty dict / empty list so every row stays index-aligned for load().
    dataset = dataset.add_column(
        "metadata", [json.dumps(self.metadata.get(i, {})) for i in range(len(dataset))]
    )
    dataset = dataset.add_column(
        "rollouts",
        [json.dumps([asdict(r) for r in self.rollout_buffer.get(i, [])]) for i in range(len(dataset))],
    )

    dataset.save_to_disk(path)

def load(self, path: Path):
    """Restores buffer state previously written by `save` from a single HF dataset."""
    try:
        dataset = Dataset.load_from_disk(path)
        columns = dataset.column_names
        num_rows = len(dataset)

        # Rebuild per-problem metadata; unparseable or empty rows fall back
        # to an empty dict so every row index stays present.
        self.metadata = {}
        if "metadata" in columns:
            for i, raw in enumerate(dataset["metadata"]):
                if not raw:
                    self.metadata[i] = {}
                    continue
                try:
                    self.metadata[i] = json.loads(raw)
                except json.JSONDecodeError as e:
                    self.logger.warning(f"Failed to parse metadata row {i}: {e}")
                    self.metadata[i] = {}
        else:
            self.metadata = {i: {} for i in range(num_rows)}

        # Rebuild rollouts; rows that are empty or unparseable are omitted
        # from the buffer entirely (a warning is logged for parse failures).
        self.rollout_buffer = {}
        if "rollouts" in columns:
            for i, raw in enumerate(dataset["rollouts"]):
                if not raw:
                    continue
                try:
                    parsed = json.loads(raw)
                    if parsed:
                        self.rollout_buffer[i] = [Rollout(**r) for r in parsed]
                except (json.JSONDecodeError, TypeError) as e:
                    self.logger.warning(f"Failed to parse rollouts row {i}: {e}")

        # Drop the serialized columns so the in-memory dataset matches its
        # pre-save shape, then rebuild the index structures from it.
        self.dataset = dataset.remove_columns(
            [c for c in ("metadata", "rollouts") if c in columns]
        )
        self.problem_ids = list(range(len(self.dataset)))
        self.problem_buffer = dict(zip(self.problem_ids, self.dataset))

    except Exception as e:
        self.logger.error(f"Failed to load buffer from {path}: {e}")
        raise

@abstractmethod
def sample_problems(self, n: int) -> tuple[list[int], list[dict]]:
"""
Expand Down Expand Up @@ -417,4 +485,4 @@ def setup_buffer(dataset: Dataset, buffer_config: DataBufferConfigType) -> Buffe
elif buffer_config.type == "difficulty-pool":
return DifficultyPoolBuffer(dataset, buffer_config)
elif buffer_config.type == "online-difficulty":
return OnlineDifficultyBuffer(dataset, buffer_config)
return OnlineDifficultyBuffer(dataset, buffer_config)
18 changes: 16 additions & 2 deletions src/prime_rl/orchestrator/ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

from prime_rl.orchestrator.buffer import Buffer
from prime_rl.orchestrator.config import CheckpointConfig
from prime_rl.utils.logger import get_logger
from prime_rl.utils.utils import get_ckpt_dir
Expand Down Expand Up @@ -38,7 +39,16 @@ def _get_step_path(self, step: int) -> Path:
def _get_ckpt_path(self, step: int) -> Path:
    # The orchestrator state file lives inside the per-step checkpoint dir.
    step_dir = self._get_step_path(step)
    return step_dir / "orchestrator.pt"

def _save_to_path(self, ckpt_path: Path, ckpt_step: int, progress: RLProgress | SFTProgress):
def get_ckpt_path(self, step: int) -> Path:
    # Public accessor for the per-step checkpoint *directory* (not the
    # orchestrator.pt file that _get_ckpt_path returns); callers append
    # sub-paths such as "buffer" to it.
    # NOTE(review): easy to confuse with _get_ckpt_path — consider renaming
    # (e.g. get_step_dir) in a follow-up.
    return self._get_step_path(step)

def _save_to_path(
self,
ckpt_path: Path,
ckpt_step: int,
buffer: Buffer,
progress: RLProgress | SFTProgress,
):
self._logger.debug(f"Saving orchestrator checkpoint to {ckpt_path}")
start_time = time.time()

Expand All @@ -49,6 +59,9 @@ def _save_to_path(self, ckpt_path: Path, ckpt_step: int, progress: RLProgress |
with open(ckpt_path, "wb") as f:
torch.save(ckpt_state, f)

buffer_path = ckpt_path.parent / "buffer"
buffer.save(buffer_path)

# Append to list of saved steps
self.ckpt_steps.append(ckpt_step)

Expand Down Expand Up @@ -79,13 +92,14 @@ def load(self, progress: RLProgress | SFTProgress, step: int) -> None:
def save(
    self,
    progress: RLProgress | SFTProgress,
    buffer: Buffer,
    step: int,
) -> None:
    """Saves the full checkpoint state for a specified step."""
    # Ensure the per-step directory exists before writing anything into it.
    step_dir = self._get_step_path(step)
    step_dir.mkdir(parents=True, exist_ok=True)
    self._save_to_path(self._get_ckpt_path(step), step, buffer, progress)

def maybe_clean(self) -> None:
"""Deletes past orchestrator checkpoints beyond the most recent `keep` steps. No-op if `keep` is None."""
Expand Down
20 changes: 19 additions & 1 deletion src/prime_rl/orchestrator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,22 @@ class CheckpointConfig(BaseConfig):

class SimpleBufferConfig(BaseModel):
    """Configuration for the "simple" data buffer type."""

    # Discriminator used by setup_buffer to pick the buffer implementation.
    type: Literal["simple"] = "simple"
    # NOTE(review): defaults to True — the orchestrator loads the buffer
    # checkpoint whenever this is set, so fresh runs (no checkpoint on disk)
    # must tolerate a missing buffer path; confirm against orchestrate().
    resume: Annotated[
        bool,
        Field(
            description="Whether to resume the data buffer from the checkpoint.",
        ),
    ] = True


class DifficultyPoolBufferConfig(BaseModel):
type: Literal["difficulty-pool"] = "difficulty-pool"
resume: Annotated[
bool,
Field(
description="Whether to resume the data buffer from the checkpoint.",
),
] = True

difficulty_field: Annotated[
str | None,
Expand Down Expand Up @@ -283,6 +295,12 @@ class DifficultyPoolBufferConfig(BaseModel):

class OnlineDifficultyBufferConfig(BaseModel):
type: Literal["online-difficulty"] = "online-difficulty"
resume: Annotated[
bool,
Field(
description="Whether to resume the data buffer from the checkpoint.",
),
] = True

min_reward: Annotated[
float | None,
Expand Down Expand Up @@ -462,4 +480,4 @@ def auto_setup_bench(self):
if self.monitor.wandb:
self.monitor.wandb.log_extras = None

return self
return self
12 changes: 8 additions & 4 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,14 @@ async def orchestrate(config: OrchestratorConfig):
vf_env = load_environment(config.environment.id, **config.environment.args)
dataset = vf_env.get_dataset(seed=config.seed)

# Setup buffer
logger.info(f"Setting up buffer ({config.buffer})")
buffer = setup_buffer(dataset, config.buffer)

if config.buffer.resume:
logger.info("Resuming buffer from checkpoint")
buffer_path = ckpt_manager.get_ckpt_path(config.ckpt.resume_step) / "buffer"
buffer.load(buffer_path)

# Iterate over dataset in batches
max_steps = config.max_steps or int(1e9)
logger.info(f"Starting orchestrator loop ({max_steps=}")
Expand All @@ -123,7 +127,7 @@ async def orchestrate(config: OrchestratorConfig):
):
logger.info(f"Saving checkpoint at step {progress.step}")
save_ckpt_start_time = time.time()
ckpt_manager.save(progress, step=progress.step)
ckpt_manager.save(progress, buffer, step=progress.step)
save_ckpt_time = time.time() - save_ckpt_start_time

# Maybe clean up old orchestrator checkpoints
Expand Down Expand Up @@ -489,7 +493,7 @@ async def orchestrate(config: OrchestratorConfig):
# Write final checkpoint
if ckpt_manager is not None:
logger.info("Writing final checkpoint")
ckpt_manager.save(progress, step=progress.step)
ckpt_manager.save(progress, buffer, step=progress.step)

logger.success("Orchestrator finished.")

Expand All @@ -506,4 +510,4 @@ def main():


if __name__ == "__main__":
main()
main()
Loading