diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker.py b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py new file mode 100644 index 000000000..dadc98b20 --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py @@ -0,0 +1,28 @@ +import hydra + +from rllm.data.dataset import DatasetRegistry +from rllm.experimental.unified_trainer import AgentTrainer +from rllm.rewards.countdown_reward import countdown_reward_fn +from rllm.workflows.simple_workflow import SimpleWorkflow + + +@hydra.main(config_path="pkg://rllm.experimental.config", config_name="unified", version_base=None) +def main(config): + train_dataset = DatasetRegistry.load_dataset("countdown", "train") + test_dataset = DatasetRegistry.load_dataset("countdown", "test") + + trainer = AgentTrainer( + workflow_class=SimpleWorkflow, + workflow_args={ + "reward_function": countdown_reward_fn, + }, + config=config, + train_dataset=train_dataset, + val_dataset=test_dataset, + backend="tinker", + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh b/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh new file mode 100644 index 000000000..1f3e8586b --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh @@ -0,0 +1,36 @@ +set -x + +python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \ + rllm/backend=tinker \ + model.name=Qwen/Qwen3-8B \ + model.lora_rank=32 \ + training.group_size=8 \ + training.learning_rate=2e-5 \ + training.max_length=4096 \ + sampling.train.temperature=1.0 \ + sampling.train.top_p=1.0 \ + sampling.val.temperature=1.0 \ + sampling.val.top_p=1.0 \ + validation.group_size=1 \ + rllm.workflow.n_parallel_tasks=256 \ + rllm.workflow.retry_limit=1 \ + rllm.workflow.raise_on_error=false \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.train_batch_size=1 \ + data.val_batch_size=1024 \ + rllm.algorithm.adv_estimator=grpo \ + rllm.algorithm.norm_adv_by_std_in_grpo=true \ + rllm.async_training.enable=true \ + rllm.async_training.mini_batch_size=32 \ + rllm.async_training.fwd_bwd_group_size=8 \ + rllm.async_training.staleness_threshold=0.5 \ + rllm.async_training.trigger_parameter_sync_step=1 \ + rllm.async_training.partial_rollout=true \ + rllm.trainer.total_epochs=1 \ + rllm.trainer.logger='[wandb]' \ + rllm.trainer.project_name='rllm-countdown' \ + rllm.trainer.experiment_name='countdown-tinker-async-staleness-0.5' \ + rllm.trainer.val_before_train=true \ + rllm.trainer.test_freq=10 \ + rllm.trainer.save_freq=-1 diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh b/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh new file mode 100644 index 000000000..a9d2748fa --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh @@ -0,0 +1,31 @@ +set -x + +python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \ + rllm/backend=tinker \ + model.name=Qwen/Qwen3-8B \ + model.lora_rank=32 \ + training.group_size=8 \ + training.learning_rate=2e-5 \ + training.max_length=4096 \ + sampling.train.temperature=1.0 \ + sampling.train.top_p=1.0 \ + sampling.val.temperature=1.0 \ + sampling.val.top_p=1.0 \ + validation.group_size=1 \ + rllm.workflow.n_parallel_tasks=256 \ + rllm.workflow.retry_limit=1 \ + rllm.workflow.raise_on_error=false \ + 
data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.train_batch_size=32 \ + data.val_batch_size=1024 \ + rllm.algorithm.adv_estimator=grpo \ + rllm.algorithm.norm_adv_by_std_in_grpo=true \ + rllm.async_training.enable=false \ + rllm.trainer.total_epochs=1 \ + rllm.trainer.logger='[wandb]' \ + rllm.trainer.project_name='rllm-countdown' \ + rllm.trainer.experiment_name='countdown-tinker-sync' \ + rllm.trainer.val_before_train=true \ + rllm.trainer.test_freq=10 \ + rllm.trainer.save_freq=-1 diff --git a/rllm/agents/agent.py b/rllm/agents/agent.py index 5ef7094d9..f72e334d3 100644 --- a/rllm/agents/agent.py +++ b/rllm/agents/agent.py @@ -2,6 +2,7 @@ import uuid from abc import ABC, abstractmethod +from copy import deepcopy from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field @@ -23,6 +24,7 @@ class Step(_StepBase): prompt_ids: list[int] | list[Any] = Field(default_factory=list) response_ids: list[int] = Field(default_factory=list) logprobs: list[float] = Field(default_factory=list) + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) chat_completions: list[dict[str, Any]] = Field(default_factory=list) @@ -38,6 +40,9 @@ class Step(_StepBase): # Per-token or scalar advantages advantage: list[float] | float | None = None + # weight version at time of generation (for async training staleness tracking) + weight_version: int | None = None + @property def info(self) -> dict: """Alias for metadata. Auto-initializes to {} if None so mutation works.""" @@ -50,6 +55,7 @@ def info(self, value: dict) -> None: self.metadata = value def model_post_init(self, __context: Any) -> None: + self.chat_completions = deepcopy(self.chat_completions) if self.model_output is None: return # backfill fields like prompt_ids, response_ids, logprobs, etc. 
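# The deepcopy added in Step.model_post_init above snapshots chat_completions at
# construction time, so a Step no longer aliases the caller's still-growing message
# list. A minimal sketch of the leak it prevents (assumes the remaining Step fields
# keep their defaults):
from rllm.agents.agent import Step

history = [{"role": "user", "content": "What is 2+2?"}]
step = Step(chat_completions=history)
history[0]["content"] = "mutated by the caller"
assert step.chat_completions[0]["content"] == "What is 2+2?"  # snapshot, not an alias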
@@ -59,17 +65,35 @@ def model_post_init(self, __context: Any) -> None: self.response_ids = self.model_output.completion_ids if len(self.logprobs) == 0 and self.model_output.logprobs is not None: self.logprobs = self.model_output.logprobs + if self.routing_matrices is None and getattr(self.model_output, "routing_matrices", None) is not None: + self.routing_matrices = self.model_output.routing_matrices + if self.weight_version is None and hasattr(self.model_output, "weight_version"): + self.weight_version = self.model_output.weight_version # check that the lengths would match up if len(self.logprobs) > 0: assert len(self.response_ids) == len(self.logprobs), f"length mismatch between response_ids and logprobs, got {len(self.response_ids)}, {len(self.logprobs)}" def to_dict(self) -> dict: + from rllm.tools.tool_base import ToolCall, ToolOutput + + # Helper function to recursively convert ToolCall and ToolOutput objects to dicts + def _serialize_value(value): + if isinstance(value, ToolCall | ToolOutput): + return value.to_dict() + elif isinstance(value, list): + return [_serialize_value(item) for item in value] + elif isinstance(value, dict): + return {k: _serialize_value(v) for k, v in value.items()} + else: + return value + return { "prompt_ids": self.prompt_ids, "response_ids": self.response_ids, "logprobs": self.logprobs, - "chat_completions": self.chat_completions, + "routing_matrices": self.routing_matrices, + "chat_completions": _serialize_value(self.chat_completions), "observation": self.observation, "thought": self.thought, "action": self.action.action if isinstance(self.action, Action) else self.action, @@ -80,6 +104,7 @@ def to_dict(self) -> dict: "done": self.done, "mc_return": self.mc_return, "advantage": self.advantage, + "weight_version": self.weight_version, } @classmethod @@ -90,6 +115,7 @@ def from_dict(cls, data: dict) -> Step: prompt_ids=data["prompt_ids"], response_ids=data["response_ids"], logprobs=data["logprobs"], + routing_matrices=data.get("routing_matrices"), chat_completions=data["chat_completions"], observation=data["observation"], thought=data["thought"], @@ -100,7 +126,8 @@ def from_dict(cls, data: dict) -> Step: reward=data["reward"], done=data["done"], mc_return=data["mc_return"], - advantage=data["advantage"], + advantage=data.get("advantage", 0.0), + weight_version=data.get("weight_version"), ) @classmethod @@ -109,11 +136,13 @@ def from_model_output(cls, model_output: ModelOutput, messages: list[dict] | Non prompt_ids=model_output.prompt_ids or [], response_ids=model_output.completion_ids or [], logprobs=model_output.logprobs or [], + routing_matrices=getattr(model_output, "routing_matrices", None), chat_completions=(messages or []) + [{"role": "assistant", "content": model_output.content, "reasoning": model_output.reasoning}], thought=model_output.reasoning or "", action=action, model_response=model_output.content or "", model_output=model_output, + weight_version=model_output.weight_version, ) @@ -259,6 +288,7 @@ class TrajectoryGroup(BaseModel): trajectories: list[Trajectory] group_id: str = "" metadata: list[dict] = Field(default_factory=list) + weight_version: int = 0 @property def group_role(self) -> str: diff --git a/rllm/experimental/buffer.py b/rllm/experimental/buffer.py new file mode 100644 index 000000000..288c72fcc --- /dev/null +++ b/rllm/experimental/buffer.py @@ -0,0 +1,273 @@ +"""TrajectoryGroupBuffer for async training. 
+ +Accumulates episodes, processes into ready-to-train trajectory groups, +with optional NVMe offloading for memory management. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import pickle +import tempfile +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tqdm import tqdm + +from rllm.agents.agent import Episode, TrajectoryGroup +from rllm.experimental.common import ( + AlgorithmConfig, + CompactFilteringConfig, + RejectionSamplingConfig, + TransformConfig, + collect_reward_and_advantage_from_trajectory_groups, +) +from rllm.experimental.common.transform import transform_episodes_to_trajectory_groups +from rllm.experimental.metrics import MetricsAggregator +from rllm.experimental.sync_coordinator import SyncCoordinator +from rllm.workflows.workflow import TerminationReason + +logger = logging.getLogger(__name__) + + +@dataclass +class TaskBatch: + """All trajectory groups produced from one task's episodes, plus stripped episodes for UI logging.""" + + groups: list[TrajectoryGroup] + episodes: list[Episode] = field(default_factory=list) + + +class TrajectoryGroupBuffer: + """Accumulates episodes, processes into trajectory groups, yields to training. + + When all rollouts for a task arrive: + 1. Record episode-level metrics to aggregator (before any filtering) + 2. Transform episodes -> trajectory groups + 3. Compact filtering + drop groups with < min_trajs_per_group + 4. Compute advantages + 5. If rejection sampling enabled: drop groups with all-zero advantage + 6. Queue the task batch for training + + Filtered groups are reported directly to the coordinator (which tracks + throttle slots and filter counts). Only non-empty task batches are queued. + All metrics flow through the shared MetricsAggregator. + + Optionally offloads pending episodes and/or queued task batches to + disk to reduce memory pressure (disabled by default). 
+ """ + + def __init__( + self, + group_size: int, + coordinator: SyncCoordinator, + aggregator: MetricsAggregator, + algorithm_config: AlgorithmConfig, + transform_config: TransformConfig, + cf_config: CompactFilteringConfig, + rs_config: RejectionSamplingConfig, + episode_offload_dir: str | None = None, + trajectory_group_offload_dir: str | None = None, + pbar: tqdm | None = None, + ): + self._group_size = group_size + self._coordinator = coordinator + self._aggregator = aggregator + self._algorithm_config = algorithm_config + self._transform_config = transform_config + self._cf_config = cf_config + self._rs_config = rs_config + self._pbar = pbar + + # Episode offloading: pending episodes serialized to disk + self._episode_offload_dir = episode_offload_dir + if episode_offload_dir: + os.makedirs(episode_offload_dir, exist_ok=True) + self._pending: dict[str, list[Episode | str]] = {} # str = offloaded file path + + # Trajectory group offloading: queued task batches serialized to disk + self._tg_offload_dir = trajectory_group_offload_dir + if trajectory_group_offload_dir: + os.makedirs(trajectory_group_offload_dir, exist_ok=True) + self._queue: asyncio.Queue[TaskBatch | str | None] = asyncio.Queue() + + async def _offload_episode(self, task_id: str, episode: Episode) -> str: + """Serialize episode to disk, return file path.""" + idx = len(self._pending.get(task_id, [])) + path = os.path.join(self._episode_offload_dir, f"{task_id}_{idx}.pkl") + await asyncio.to_thread(self._pickle_dump, path, episode) + return path + + async def _load_pending_episodes(self, task_id: str) -> list[Episode]: + """Load all pending episodes for a task, deserializing offloaded ones.""" + episodes = [] + for item in self._pending.pop(task_id, []): + if isinstance(item, str): + ep = await asyncio.to_thread(self._pickle_load, item) + episodes.append(ep) + else: + episodes.append(item) + return episodes + + async def _offload_task_batch(self, batch: TaskBatch) -> str: + """Serialize task batch to disk, return file path.""" + fd, path = tempfile.mkstemp(dir=self._tg_offload_dir, suffix=".pkl") + os.close(fd) + await asyncio.to_thread(self._pickle_dump, path, batch) + return path + + async def _load_task_batch(self, item: TaskBatch | str) -> TaskBatch: + """Load task batch, deserializing if offloaded.""" + if isinstance(item, str): + return await asyncio.to_thread(self._pickle_load, item) + return item + + @staticmethod + def _pickle_dump(path: str, obj) -> None: + with open(path, "wb") as f: + pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) + + @staticmethod + def _pickle_load(path: str): + with open(path, "rb") as f: + obj = pickle.load(f) + os.remove(path) + return obj + + async def add_episode(self, task_id: str, episode: Episode) -> bool: + """Add episode. When group completes, process and queue task batch.""" + # Offload episode to disk if enabled + if self._episode_offload_dir: + path = await self._offload_episode(task_id, episode) + self._pending.setdefault(task_id, []).append(path) + else: + self._pending.setdefault(task_id, []).append(episode) + + if len(self._pending[task_id]) < self._group_size: + return False + + # Group complete — tick progress bar + if self._pbar is not None: + self._pbar.update(1) + + # Load all episodes + if self._episode_offload_dir: + episodes = await self._load_pending_episodes(task_id) + else: + episodes = self._pending.pop(task_id, []) + + weight_version = self._min_weight_version(episodes) + + # 1. 
Record episode-level metrics (includes filtered tasks) + self._record_episode_metrics(episodes) + + # 2. Transform episodes -> trajectory groups + traj_groups, transform_metrics = transform_episodes_to_trajectory_groups( + episodes, + self._transform_config, + self._cf_config, + ) + self._aggregator.record_dict(transform_metrics) + + # 3. Drop groups with too few trajectories + before_min_traj = len(traj_groups) + traj_groups = [g for g in traj_groups if len(g.trajectories) >= self._rs_config.min_trajs_per_group] + self._aggregator.record("groups/dropped_min_trajs", before_min_traj - len(traj_groups)) + + if not traj_groups: + self._coordinator.on_group_filtered() + return True + + # 4. Compute advantages + adv_metrics = collect_reward_and_advantage_from_trajectory_groups( + traj_groups, + self._algorithm_config, + ) + self._aggregator.record_dict(adv_metrics) + + # 5. Rejection sampling: drop groups with all-zero advantage + filtered_zero_adv = 0 + if self._rs_config.filter_uniform_groups: + before_adv = len(traj_groups) + traj_groups = [g for g in traj_groups if any(abs(step.advantage) > 1e-8 for traj in g.trajectories for step in traj.steps if step.advantage is not None)] + filtered_zero_adv = before_adv - len(traj_groups) + self._aggregator.record("groups/dropped_zero_adv", filtered_zero_adv) + + if not traj_groups: + self._coordinator.on_group_filtered() + return True + + # 6. Set weight version and queue + for g in traj_groups: + g.weight_version = weight_version + + batch = TaskBatch(groups=traj_groups, episodes=episodes) + if self._tg_offload_dir: + await self._queue.put(await self._offload_task_batch(batch)) + else: + await self._queue.put(batch) + + return True + + async def get(self) -> TaskBatch | None: + """Get next task batch. Returns None when generation is done and buffer is drained.""" + item = await self._queue.get() + if item is None: + return None + return await self._load_task_batch(item) + + def mark_generation_complete(self) -> None: + """Signal that generation is finished. 
Flushes incomplete groups and enqueues a sentinel.""" + for task_id in list(self._pending.keys()): + items = self._pending.pop(task_id, []) + for item in items: + if isinstance(item, str): + try: + os.remove(item) + except OSError: + pass + self._coordinator.on_group_filtered() + self._queue.put_nowait(None) + + def stats(self) -> dict: + return { + "async/buffer_qsize": self._queue.qsize(), + "async/buffer_pending": len(self._pending), + } + + def _record_episode_metrics(self, episodes: list[Episode]) -> None: + """Record episode-level metrics to aggregator (all episodes, including filtered).""" + for ep in episodes: + reason = ep.termination_reason or TerminationReason.UNKNOWN + for r in TerminationReason: + self._aggregator.record( + f"episode/termination_reason/{r.value}", + 1.0 if reason == r else 0.0, + ) + for k, v in ep.metrics.items(): + try: + self._aggregator.record(f"episode/{k}", float(v)) + except (TypeError, ValueError): + continue + + # Episode-level totals across all trajectories + total_turns = sum(len(traj.steps) for traj in ep.trajectories) + total_prompt_tokens = sum(len(s.prompt_ids) for traj in ep.trajectories for s in traj.steps) + total_response_tokens = sum(len(s.response_ids) for traj in ep.trajectories for s in traj.steps) + self._aggregator.record("episode/num_turns", total_turns) + self._aggregator.record("episode/prompt_tokens", total_prompt_tokens) + self._aggregator.record("episode/response_tokens", total_response_tokens) + self._aggregator.record("episode/correct", 1.0 if ep.is_correct else 0.0) + + @staticmethod + def _min_weight_version(episodes: list[Episode]) -> int: + min_v = float("inf") + for ep in episodes: + for traj in ep.trajectories: + for step in traj.steps: + if step.weight_version is not None: + min_v = min(min_v, step.weight_version) + return int(min_v) if min_v != float("inf") else 0 diff --git a/rllm/experimental/common/__init__.py b/rllm/experimental/common/__init__.py index ed169b372..b75f43c90 100644 --- a/rllm/experimental/common/__init__.py +++ b/rllm/experimental/common/__init__.py @@ -7,8 +7,10 @@ from rllm.experimental.common.advantage import collect_reward_and_advantage_from_trajectory_groups from rllm.experimental.common.config import ( AlgorithmConfig, + AsyncTrainingConfig, CompactFilteringConfig, RejectionSamplingConfig, + RolloutCorrectionConfig, TransformConfig, rLLMAdvantageEstimator, ) @@ -24,8 +26,10 @@ __all__ = [ # Config + "AsyncTrainingConfig", "CompactFilteringConfig", "RejectionSamplingConfig", + "RolloutCorrectionConfig", "TransformConfig", "AlgorithmConfig", # Transform pipeline diff --git a/rllm/experimental/common/config.py b/rllm/experimental/common/config.py index 1e82fa08c..dee2f2d13 100644 --- a/rllm/experimental/common/config.py +++ b/rllm/experimental/common/config.py @@ -8,6 +8,38 @@ from rllm.workflows.workflow import TerminationReason +@dataclass +class AsyncTrainingConfig: + """Controls the async training behavior spectrum. + + When `enable` is False, the trainer uses the current synchronous pipeline. + When `enable` is True, the trainer runs concurrent generation + training + with group-level streaming and dispatch-time throttle. 
+ + Behavior spectrum: + - staleness_threshold=0, trigger_parameter_sync_step=1: On-policy + - staleness_threshold=0, trigger_parameter_sync_step=K: Stream off-policy + - staleness_threshold>0, partial_rollout=False: Async with staleness + - staleness_threshold>0, partial_rollout=True: Async with partial rollout + """ + + enable: bool = False + mini_batch_size: int = 1 # episode groups per optimizer step + fwd_bwd_group_size: int | None = None # task batches per forward-backward pass (default: mini_batch_size) + staleness_threshold: float = 0.0 # 0.0 = on-policy. Controls dispatch throttle quota. + trigger_parameter_sync_step: int = 1 # optimizer steps between weight sync + version bump + partial_rollout: bool = True # enable turn-level gating during weight sync + episode_offload_dir: str | None = None # NVMe offload dir for pending episodes (None = disabled) + trajectory_group_offload_dir: str | None = None # NVMe offload dir for queued task batches (None = disabled) + + def __post_init__(self): + if self.fwd_bwd_group_size is None: + self.fwd_bwd_group_size = self.mini_batch_size + if self.enable: + assert self.fwd_bwd_group_size >= 1 + assert self.mini_batch_size % self.fwd_bwd_group_size == 0, f"mini_batch_size ({self.mini_batch_size}) must be divisible by fwd_bwd_group_size ({self.fwd_bwd_group_size})" + + @dataclass class CompactFilteringConfig: """Configuration for compact filtering of episodes based on termination reasons. @@ -93,6 +125,30 @@ class RejectionSamplingConfig: # For "episode" mode (verl compatibility): minimum number of tasks with partial solves before proceeding min_partial_solve_tasks: int = 1 + # Filter out episode groups where all rollouts have the same is_correct (no gradient signal). + # Applied at the accumulator level in async training, before groups enter the buffer. + filter_uniform_groups: bool = False + + +@dataclass +class RolloutCorrectionConfig: + """Configuration for rollout correction (TIS, proximal forward passes). + + Backend-agnostic — each backend interprets these according to its infrastructure. + + Attributes: + tis_mode: None = disabled (string loss names, current behavior). + "token" or "sequence" = enable custom callable loss with TIS at that level. + bypass_mode: When True, use rollout (inference) logprobs as π_old — no + proximal forward pass. When False, compute π_old via policy.forward() + (3-policy / decoupled PPO). + tis_cap: Upper clamp on the TIS importance weight. + """ + + tis_mode: str | None = None + bypass_mode: bool = True + tis_cap: float = 5.0 + class rLLMAdvantageEstimator(str, Enum): """ @@ -143,6 +199,14 @@ class AlgorithmConfig: lr_schedule: Literal["linear", "cosine", "constant"] = "constant" warmup_steps_ratio: float = 0.0 + # Custom loss / rollout correction fields (used by Fireworks backend with cookbook losses) + kl_beta: float = 0.0 + eps_clip: float = 0.2 + eps_clip_high: float | None = None + loss_agg_mode: Literal["token_mean", "seq_mean_token_sum", "seq_mean_token_mean", None] = None + rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig) + router_replay: bool = False + @classmethod def from_config(cls, config: DictConfig) -> "AlgorithmConfig": """Create an AlgorithmConfig from a dictionary configuration. @@ -152,6 +216,12 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": Returns: AlgorithmConfig: The AlgorithmConfig built from the configuration. 
""" + rc_section = config.rllm.algorithm.get("rollout_correction", {}) + rollout_correction = RolloutCorrectionConfig( + tis_mode=rc_section.get("tis_mode", None), + bypass_mode=rc_section.get("bypass_mode", True), + tis_cap=rc_section.get("tis_cap", 5.0), + ) return cls( estimator=rLLMAdvantageEstimator(config.algorithm.adv_estimator), stepwise_advantage_mode=config.rllm.stepwise_advantage.mode, @@ -161,6 +231,12 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": loss_fn=config.rllm.algorithm.get("loss_fn", None), lr_schedule=config.rllm.algorithm.get("lr_schedule", "constant"), warmup_steps_ratio=config.rllm.algorithm.get("warmup_steps_ratio", 0.0), + kl_beta=config.rllm.algorithm.get("kl_beta", 0.0), + eps_clip=config.rllm.algorithm.get("eps_clip", 0.2), + eps_clip_high=config.rllm.algorithm.get("eps_clip_high", None), + loss_agg_mode=config.rllm.algorithm.get("loss_agg_mode", None), + rollout_correction=rollout_correction, + router_replay=config.rllm.algorithm.get("router_replay", False), ) def __post_init__(self): diff --git a/rllm/experimental/common/transform.py b/rllm/experimental/common/transform.py index 296ecda35..f13e70f29 100644 --- a/rllm/experimental/common/transform.py +++ b/rllm/experimental/common/transform.py @@ -152,7 +152,7 @@ def _build_trajectory_groups(episodes: list[Episode], compact_filtering_config: """ -def _get_transform_metrics(episodes: list[Episode], groups: list[TrajectoryGroup], prefix: str = "grouping") -> dict: +def _get_transform_metrics(episodes: list[Episode], groups: list[TrajectoryGroup], prefix: str = "groups") -> dict: """ Get metrics for the transformation pipeline. """ @@ -182,12 +182,11 @@ def _default_traj_grouping_hook(episodes: list[Episode], transform_config: Trans """ trajectory_groups = _build_trajectory_groups(episodes, compact_filtering_config) # part 1 reward_warnings = _validate_and_propagate_rewards(trajectory_groups, transform_config) # part 2 - - for warning in reward_warnings[:LOG_N_WARNINGS]: - logger.warning(warning) - - if len(reward_warnings) > LOG_N_WARNINGS: - logger.warning(f"Skipping {len(reward_warnings) - LOG_N_WARNINGS} more similar warnings with reward validation") + if reward_warnings: + for warning in reward_warnings[:LOG_N_WARNINGS]: + logger.debug(warning) + if len(reward_warnings) > LOG_N_WARNINGS: + logger.debug(f"Skipping {len(reward_warnings) - LOG_N_WARNINGS} more similar reward validation warnings") return trajectory_groups @@ -196,7 +195,7 @@ def transform_episodes_to_trajectory_groups( episodes: list[Episode], transform_config: TransformConfig, compact_filtering_config: CompactFilteringConfig | None = None, - metrics_prefix: str = "grouping", + metrics_prefix: str = "groups", traj_grouping_hook: Callable[[list[Episode], TransformConfig, CompactFilteringConfig | None], list[TrajectoryGroup]] = _default_traj_grouping_hook, ) -> tuple[list[TrajectoryGroup], dict]: """ @@ -234,12 +233,11 @@ def transform_episodes_to_trajectory_groups( # Step 1: Name imputation rename_warnings = _impute_trajectory_names(episodes, transform_config) - - for warning in rename_warnings[:LOG_N_WARNINGS]: - logger.warning(warning) - - if len(rename_warnings) > LOG_N_WARNINGS: - logger.warning(f"Skipping {len(rename_warnings) - LOG_N_WARNINGS} more similar warnings with trajectory names") + if rename_warnings: + for warning in rename_warnings[:LOG_N_WARNINGS]: + logger.debug(warning) + if len(rename_warnings) > LOG_N_WARNINGS: + logger.debug(f"Skipping {len(rename_warnings) - LOG_N_WARNINGS} more similar trajectory name 
warnings") # Step 2: Invoke the trajectory grouping hook groups = traj_grouping_hook(episodes, transform_config, compact_filtering_config) diff --git a/rllm/experimental/common/visualization.py b/rllm/experimental/common/visualization.py index cee8e9d01..d7d45d242 100644 --- a/rllm/experimental/common/visualization.py +++ b/rllm/experimental/common/visualization.py @@ -24,6 +24,40 @@ class VisualizationConfig: failure_style: dict[str, Any] = field(default_factory=lambda: {"fg": "red", "bold": True}) +def print_metrics_table(metrics: dict, step: int, title: str | None = None) -> None: + """Print metrics as a formatted Rich table with fallback to plain text.""" + try: + from rich.console import Console + from rich.table import Table + + table = Table(title=title or f"Step {step}", show_header=True, header_style="bold magenta") + table.add_column("Metric", style="cyan", no_wrap=False) + table.add_column("Value", justify="right", style="green") + + for key, value in sorted(metrics.items()): + if isinstance(value, float): + value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" + elif isinstance(value, int): + value_str = str(value) + else: + value_str = str(value) + table.add_row(key, value_str) + + Console().print(table) + except ImportError: + print(f"\n{title or f'Step {step}'}") + print("=" * 60) + for key, value in sorted(metrics.items()): + if isinstance(value, float): + value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" + elif isinstance(value, int): + value_str = str(value) + else: + value_str = str(value) + print(f"{key:40s} {value_str:>15s}") + print("=" * 60) + + def colorful_print(string: str, *args, **kwargs) -> None: end = kwargs.pop("end", "\n") print(click.style(string, *args, **kwargs), end=end, flush=True) diff --git a/rllm/experimental/config/rllm/base.yaml b/rllm/experimental/config/rllm/base.yaml index 152310f9b..9aaee0296 100644 --- a/rllm/experimental/config/rllm/base.yaml +++ b/rllm/experimental/config/rllm/base.yaml @@ -56,8 +56,16 @@ algorithm: # When true, always use pre-computed step.advantage from the workflow (e.g. distillation) # and skip advantage computation (GRPO/REINFORCE). Missing advantages default to 0. use_precomputed_advantage: false - # for tinker backend only (avaiable options: importance_sampling, ppo, cispo, dro, cross_entropy) - loss_fn: null + loss_fn: null # [null, importance_sampling, ppo, cispo, dro, cross_entropy] + loss_agg_mode: null # [null, token-mean, seq-mean-token-sum, seq-mean-token-mean] + kl_beta: 0.0 # KL penalty coefficient + eps_clip: 0.2 # PPO clip epsilon + eps_clip_high: null # Asymmetric upper clip bound (null = symmetric) + router_replay: false # Router Replay (R3): replay MoE expert routing from inference during training + rollout_correction: + bypass_mode: true # true = use rollout logprobs as pi_old (2-policy), false = proximal forward (3-policy) + tis_mode: null # null = disabled, "token" or "sequence" = TIS importance sampling level + tis_cap: 5.0 # Upper clamp on TIS importance weight # Stepwise advantage # TODO(listar2000): deprecate the `per_step` mode and refactor this config. 
@@ -93,6 +101,7 @@ rejection_sample: multiplier: 1 min_partial_solve_tasks: 1 min_trajs_per_group: 2 + filter_uniform_groups: false # SDK Configuration # DEPRECATED: This section is only kept for backward compatibility with @@ -153,6 +162,17 @@ gateway: host: null # Auto-detects routable IP; set explicitly to override db_path: null # Defaults to temp file +# Async Training Configuration +async_training: + enable: false + mini_batch_size: 1 + fwd_bwd_group_size: null # task batches per forward-backward pass (default: mini_batch_size) + staleness_threshold: 0.0 + trigger_parameter_sync_step: 1 + partial_rollout: true + episode_offload_dir: null # NVMe offload dir for pending episodes (null = disabled) + trajectory_group_offload_dir: null # NVMe offload dir for queued task batches (null = disabled) + # Episode Logging Configuration episode_logging: log_episodes: false diff --git a/rllm/experimental/engine/agent_flow_engine.py b/rllm/experimental/engine/agent_flow_engine.py index 60c3ccdd6..1b582dc31 100644 --- a/rllm/experimental/engine/agent_flow_engine.py +++ b/rllm/experimental/engine/agent_flow_engine.py @@ -74,6 +74,7 @@ def __init__( self.raise_on_error = raise_on_error self.episode_logger = episode_logger self.executor = ThreadPoolExecutor(max_workers=n_parallel_tasks) + self._semaphore = asyncio.Semaphore(n_parallel_tasks) # Raise the file descriptor limit to avoid "Too many open files" when # running many parallel agent flows with individual HTTP clients. @@ -117,7 +118,7 @@ async def execute_tasks( for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): rollout_idx = task_id_counter[task_id] task_id_counter[task_id] += 1 - futures.append(self._process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) + futures.append(self.process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) with tqdm(total=len(tasks), desc="Generating trajectories") as pbar: for future in asyncio.as_completed(futures): @@ -141,7 +142,7 @@ async def execute_tasks( return ordered_results - async def _process_task_with_retry( + async def process_task_with_retry( self, task: dict, task_id: str, @@ -150,64 +151,61 @@ async def _process_task_with_retry( is_validation: bool = False, ) -> tuple[str, int, int, Episode]: """Process a single task with retry logic.""" - for retry_attempt in range(1, self.retry_limit + 1): - uid = f"{task_id}:{rollout_idx}" - try: - episode = await self._run_single(task, uid, is_validation=is_validation) - episode.id = uid - episode.task = task - - # Display rewards - reward_strs = [] - for traj in episode.trajectories: - reward = "N/A" - if traj.reward is not None: - reward = f"{traj.reward:.1f}" - elif len(traj.steps) > 0: - reward = f"{traj.steps[-1].reward:.1f}" - reward_strs.append(f"{traj.name}: {reward}") - colorful_print( - f"[{uid}] Rollout completed. 
Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", - fg="green" if episode.is_correct else "yellow", - ) - - return task_id, rollout_idx, result_idx, episode - - except Exception as e: - logger.error("[%s] Attempt %d/%d failed: %s", uid, retry_attempt, self.retry_limit, e) - if retry_attempt < self.retry_limit: - continue - - if self.raise_on_error: - raise - - # Return an error episode - return ( - task_id, - rollout_idx, - result_idx, - Episode( - id=uid, - task=task, - is_correct=False, - termination_reason=TerminationReason.ERROR, - metadata={"error": {"message": str(e)}}, - ), - ) - - # Should not reach here, but satisfy type checker - raise RuntimeError(f"[{uid}] Exhausted all retries") + async with self._semaphore: + for retry_attempt in range(1, self.retry_limit + 1): + uid = f"{task_id}:{rollout_idx}" + try: + episode = await self._run_single(task, uid, is_validation=is_validation) + episode.id = uid + episode.task = task + + # Display rewards + reward_strs = [] + for traj in episode.trajectories: + reward = "N/A" + if traj.reward is not None: + reward = f"{traj.reward:.1f}" + elif len(traj.steps) > 0: + reward = f"{traj.steps[-1].reward:.1f}" + reward_strs.append(f"{traj.name}: {reward}") + colorful_print( + f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", + fg="green" if episode.is_correct else "yellow", + ) + + return task_id, rollout_idx, result_idx, episode + + except Exception as e: + logger.error("[%s] Attempt %d/%d failed: %s", uid, retry_attempt, self.retry_limit, e) + if retry_attempt < self.retry_limit: + continue + + if self.raise_on_error: + raise + + # Return an error episode + return ( + task_id, + rollout_idx, + result_idx, + Episode( + id=uid, + task=task, + is_correct=False, + termination_reason=TerminationReason.ERROR, + metadata={"error": {"message": str(e)}}, + ), + ) + + # Should not reach here, but satisfy type checker + raise RuntimeError(f"[{uid}] Exhausted all retries") async def _run_single(self, task: dict, uid: str, is_validation: bool = False) -> Episode: """Run a single AgentFlow task: execute, evaluate, enrich.""" loop = asyncio.get_event_loop() - # 1. Create gateway session (run in executor to avoid blocking event loop) - await loop.run_in_executor( - self.executor, - self.gateway.create_session, - uid, - ) + # 1. Create gateway session + await self.gateway.acreate_session(uid, is_validation=is_validation) session_url = self.gateway.get_session_url(uid) # 2. Build config @@ -239,12 +237,8 @@ async def _run_single(self, task: dict, uid: str, is_validation: bool = False) - traj.reward = eval_output.reward episode.is_correct = eval_output.is_correct - # 5. Retrieve traces from gateway (run in executor to avoid blocking event loop) - traces = await loop.run_in_executor( - self.executor, - self.gateway.get_traces, - uid, - ) + # 5. Retrieve traces from gateway + traces = await self.gateway.aget_traces(uid) # 6. 
Enrich episode with token data enriched = self._enrich_episode(episode, traces, uid, task) diff --git a/rllm/experimental/engine/gateway_manager.py b/rllm/experimental/engine/gateway_manager.py index 44e75b474..abcc24191 100644 --- a/rllm/experimental/engine/gateway_manager.py +++ b/rllm/experimental/engine/gateway_manager.py @@ -18,7 +18,7 @@ import time from typing import TYPE_CHECKING, Any -from rllm_model_gateway.client import GatewayClient +from rllm_model_gateway.client import AsyncGatewayClient, GatewayClient from rllm_model_gateway.models import TraceRecord if TYPE_CHECKING: @@ -89,6 +89,7 @@ def __init__(self, config: DictConfig, mode: str = "thread") -> None: self._server: Any = None # uvicorn.Server when using thread mode self._local_handler: Any = None # in-process handler for tinker self._client: GatewayClient | None = None + self._async_client: AsyncGatewayClient | None = None # Per-mode sampling params (extracted from rollout engine in start()) self._train_sampling_params: dict[str, Any] = {} @@ -100,10 +101,18 @@ def gateway_url(self) -> str: @property def client(self) -> GatewayClient: + """Sync client for lifecycle operations (start, stop, health polling).""" if self._client is None: self._client = GatewayClient(self.gateway_url) return self._client + @property + def async_client(self) -> AsyncGatewayClient: + """Async client for runtime operations (sessions, traces).""" + if self._async_client is None: + self._async_client = AsyncGatewayClient(self.gateway_url) + return self._async_client + # -- Lifecycle ----------------------------------------------------------- def start(self, rollout_engine: RolloutEngine) -> None: @@ -171,6 +180,16 @@ def get_traces(self, session_id: str) -> list[TraceRecord]: self.client.flush() return self.client.get_session_traces(session_id) + # -- Async session / trace API ------------------------------------------- + + async def acreate_session(self, session_id: str, is_validation: bool = False) -> str: + sp = self._val_sampling_params if is_validation else self._train_sampling_params + return await self.async_client.create_session(session_id=session_id, sampling_params=sp or None) + + async def aget_traces(self, session_id: str) -> list[TraceRecord]: + await self.async_client.flush() + return await self.async_client.get_session_traces(session_id) + # -- Worker setup -------------------------------------------------------- def _ensure_workers(self, rollout_engine: RolloutEngine) -> list[str]: diff --git a/rllm/experimental/engine/remote_agent_flow_engine.py b/rllm/experimental/engine/remote_agent_flow_engine.py index c59adc814..0bf67d6f6 100644 --- a/rllm/experimental/engine/remote_agent_flow_engine.py +++ b/rllm/experimental/engine/remote_agent_flow_engine.py @@ -5,6 +5,7 @@ converting gateway traces to training Steps. 
""" +import asyncio import logging import uuid from collections import defaultdict @@ -31,12 +32,15 @@ def __init__( runtime: RemoteAgentRuntime, gateway: GatewayManager, session_timeout: float = 900.0, + n_parallel_tasks: int = 128, episode_logger: EpisodeLogger | None = None, ) -> None: self.runtime = runtime self.gateway = gateway self.session_timeout = session_timeout + self.n_parallel_tasks = n_parallel_tasks self.episode_logger = episode_logger + self._semaphore = asyncio.Semaphore(n_parallel_tasks) # Training step tracking (set by set_training_step) self.current_step = 0 @@ -55,59 +59,24 @@ async def execute_tasks( is_validation: bool = False, **kwargs, ) -> list[Episode]: - """Submit tasks to remote runtime, gather results, build Episodes from gateway traces. - - 1. Prepare submissions (create gateway sessions) - 2. Submit all and gather results concurrently via runtime - 3. Retrieve traces from gateway + build Episodes - """ + """Submit tasks to remote runtime, gather results, build Episodes from gateway traces.""" if task_ids is None: task_ids = [str(uuid.uuid4()) for _ in tasks] - # Phase 1: Prepare submissions task_id_counter: dict[str, int] = defaultdict(int) - submissions: list[TaskSubmission] = [] - # Map session_id -> (idx, uid, task) for result correlation - session_metadata: dict[str, tuple[int, str, dict]] = {} + results: list[Episode | None] = [None] * len(tasks) + futures = [] for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): rollout_idx = task_id_counter[task_id] task_id_counter[task_id] += 1 - uid = f"{task_id}:{rollout_idx}" - session_id = str(uuid.uuid4()) - - self.gateway.create_session(session_id, is_validation=is_validation) - session_url = self.gateway.get_session_url(session_id) - - submissions.append( - TaskSubmission( - task=task, - session_id=session_id, - task_id=task_id, - inference_url=session_url, - ) - ) - session_metadata[session_id] = (idx, uid, task) - - # Phase 2: Submit all and gather results concurrently - logger.info("Submitting %d tasks to remote runtime (timeout=%.0fs)", len(submissions), self.session_timeout) - remote_results = await self.runtime.execute_tasks(submissions, timeout=self.session_timeout) + futures.append(self.process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) - # Phase 3: Retrieve traces from gateway + build Episodes (match by session_id) - episode_map: dict[int, Episode] = {} + for future in asyncio.as_completed(futures): + task_id, rollout_idx, idx, episode = await future + results[idx] = episode - for result in remote_results: - idx, uid, task = session_metadata[result.session_id] - if not result.finished: - logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) - result.reward = 0.0 - traces = self.gateway.get_traces(result.session_id) - episode = _build_episode(traces, result, uid, task) - if not result.finished: - episode.metadata["error"] = {"message": result.error or "Unknown error"} - episode_map[idx] = episode - - episodes = [episode_map[i] for i in range(len(tasks))] + episodes: list[Episode] = results # type: ignore[assignment] # Log episodes if logger is provided if self.episode_logger is not None: @@ -123,6 +92,43 @@ async def execute_tasks( return episodes + async def process_task_with_retry( + self, + task: dict, + task_id: str, + rollout_idx: int, + result_idx: int, + **kwargs, + ) -> tuple[str, int, int, Episode]: + """Process a single task with concurrency control.""" + async with self._semaphore: + uid = 
f"{task_id}:{rollout_idx}" + session_id = str(uuid.uuid4()) + is_validation = kwargs.get("is_validation", False) + + await self.gateway.acreate_session(session_id, is_validation=is_validation) + session_url = self.gateway.get_session_url(session_id) + + submission = TaskSubmission( + task=task, + session_id=session_id, + task_id=task_id, + inference_url=session_url, + ) + results = await self.runtime.execute_tasks([submission], timeout=self.session_timeout) + result = results[0] + + if not result.finished: + logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) + result.reward = 0.0 + + traces = await self.gateway.aget_traces(session_id) + episode = _build_episode(traces, result, uid, task) + if not result.finished: + episode.metadata["error"] = {"message": result.error or "Unknown error"} + + return task_id, rollout_idx, result_idx, episode + def shutdown(self) -> None: """No local resources to clean up (runtime shutdown is separate).""" pass diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index da03c4d4f..037cbd9a3 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -12,12 +12,12 @@ from rllm.agents.agent import Episode from rllm.experimental.rollout import RolloutEngine -from rllm.utils import colorful_print from rllm.workflows.store import Store from rllm.workflows.workflow import TerminationReason, Workflow # Avoid hard dependency on verl at import time; only for typing if TYPE_CHECKING: + from omegaconf import DictConfig from verl import DataProto from rllm.utils.episode_logger import EpisodeLogger @@ -31,7 +31,7 @@ def __init__( workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, - config=None, + config: DictConfig | None = None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, @@ -104,6 +104,7 @@ async def initialize_pool(self): assert self.executor is not None, "executor is not initialized" if self.workflow_queue is not None: return + logger.info(f"[WorkflowEngine] Initializing pool with {self.n_parallel_tasks} workflows") self.workflow_queue = asyncio.Queue(maxsize=self.n_parallel_tasks) for i in range(self.n_parallel_tasks): workflow = self.workflow_cls( @@ -114,6 +115,7 @@ async def initialize_pool(self): ) assert workflow.is_multithread_safe(), "Workflows must contain only thread-save environments" self.workflow_queue.put_nowait(workflow) + logger.info(f"[WorkflowEngine] Pool initialized. Queue size: {self.workflow_queue.qsize()}") async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: int, result_idx: int, **kwargs) -> tuple[str, int, int, Episode]: """Process a single task rollout with retry logic based on termination reasons. @@ -132,10 +134,12 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i Exception: If task fails permanently after retry_limit attempts and raise_on_error is True. """ assert self.workflow_queue is not None, "workflow_queue is not initialized" + logger.debug(f"[WorkflowEngine] Waiting for workflow from queue. 
Available: {self.workflow_queue.qsize()}") workflow = await self.workflow_queue.get() try: for retry_attempt in range(1, self.retry_limit + 1): uid = f"{task_id}:{rollout_idx}" + logger.debug(f"[WorkflowEngine] [{uid}] Starting attempt {retry_attempt}/{self.retry_limit}") workflow.reset(task=task, uid=uid) episode = await workflow.run_with_termination_handling(task=task, uid=uid, **kwargs) @@ -152,24 +156,21 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i elif len(traj.steps) > 0: reward = f"{traj.steps[-1].reward:.1f}" reward_strs.append(f"{traj.name}: {reward}") - colorful_print( - f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", - fg="green" if episode.is_correct else "yellow", - ) + logger.debug(f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}") if episode.termination_reason != TerminationReason.ERROR: return task_id, rollout_idx, result_idx, episode error_tb = episode.info.get("error", {}).get("traceback") if error_tb: - print(error_tb) + logger.error(f"[WorkflowEngine] [{uid}] Error on attempt {retry_attempt}/{self.retry_limit}:\n{error_tb}") if retry_attempt < self.retry_limit: - print(f"[{uid}] Rollout failed on attempt {retry_attempt}/{self.retry_limit}, retrying...") + logger.warning(f"[WorkflowEngine] [{uid}] Rollout failed on attempt {retry_attempt}/{self.retry_limit}, retrying...") continue if not self.raise_on_error: - print(f"[{uid}] Rollout failed permanently after {self.retry_limit} attempts.") + logger.error(f"[WorkflowEngine] [{uid}] Rollout failed permanently after {self.retry_limit} attempts.") else: raise Exception(f"[{uid}] Rollout failed permanently after {self.retry_limit} attempts.") @@ -177,6 +178,7 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i finally: await self.workflow_queue.put(workflow) + logger.debug(f"[WorkflowEngine] Returned workflow to queue. Available: {self.workflow_queue.qsize()}") async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = None, is_validation: bool = False, **kwargs) -> list[Episode]: """Run asynchronous workflow execution with retry logic for multiple tasks. diff --git a/rllm/experimental/metrics.py b/rllm/experimental/metrics.py new file mode 100644 index 000000000..fe2ce1ef4 --- /dev/null +++ b/rllm/experimental/metrics.py @@ -0,0 +1,119 @@ +"""MetricsAggregator for async training. + +Accumulates metric observations from multiple sources (buffer, training loop, +coordinator) and reduces them with per-key aggregation rules at flush time. +""" + +from __future__ import annotations + +from collections import defaultdict + +import numpy as np + +# Keys that should be summed rather than averaged. +_SUM_KEYS: set[str] = { + "groups/num_trajs_before_filter", + "groups/num_trajs_after_filter", + "groups/num_groups", + "groups/dropped_min_trajs", + "groups/dropped_zero_adv", +} + +# Prefixes where "last value" is the correct reduction. +_LAST_PREFIXES: tuple[str, ...] = ( + "time/", + "train/", + "progress/", + "async/", +) + +# Prefixes where "mean" is the correct reduction. +_MEAN_PREFIXES: tuple[str, ...] = ("episode/",) + + +def _infer_rule(key: str) -> str: + """Infer aggregation rule from metric key name. + + Resolution order: + 1. Explicit sum keys + 2. Prefix-based rules (last or mean) + 3. Keyword-based rules (/max, /min, /mean, /avg, /std, /fraction) + 4. 
Default: mean + """ + if key in _SUM_KEYS: + return "sum" + + for prefix in _LAST_PREFIXES: + if key.startswith(prefix): + return "last" + + for prefix in _MEAN_PREFIXES: + if key.startswith(prefix): + return "mean" + + # Keyword inference from the key name + if "/max" in key: + return "max" + if "/min" in key: + return "min" + if "/mean" in key or "/avg" in key: + return "mean" + if "/std" in key or "/fraction" in key: + return "mean" + + return "mean" + + +def _reduce(rule: str, values: list[float]) -> float: + if rule == "mean": + return sum(values) / len(values) + if rule == "sum": + return sum(values) + if rule == "max": + return max(values) + if rule == "min": + return min(values) + if rule == "last": + return values[-1] + return sum(values) / len(values) + + +class MetricsAggregator: + """Accumulates metric observations and flushes as an aggregated plain dict. + + Usage:: + + agg = MetricsAggregator() + + # record from various sources + agg.record("episode/queue_wait", 0.3) + agg.record("episode/queue_wait", 0.5) + agg.record_dict(transform_metrics) + + # at log time + plain_dict = agg.flush() # reduces, clears, returns dict + """ + + def __init__(self) -> None: + self._values: dict[str, list[float]] = defaultdict(list) + + def record(self, key: str, value: float) -> None: + """Record a single metric observation.""" + self._values[key].append(float(value)) + + def record_dict(self, metrics: dict) -> None: + """Record all numeric values from a dict, coercing types.""" + for k, v in metrics.items(): + if isinstance(v, int | float): + self._values[k].append(float(v)) + elif isinstance(v, np.number): + self._values[k].append(float(v)) + + def flush(self) -> dict[str, float]: + """Reduce all accumulated values and return a plain dict. Clears state.""" + result = {} + for key, values in self._values.items(): + if values: + result[key] = _reduce(_infer_rule(key), values) + self._values.clear() + return result diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index 0c8491dbe..ebdb79a0b 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -205,6 +205,10 @@ async def on_epoch_end(self, trainer_state: TrainerState) -> None: """Hook method called at the end of an epoch.""" pass + async def on_policy_updated(self, trainer_state: TrainerState) -> None: + """Hook called immediately after update_policy() for weight sync.""" + pass + async def on_validation_start(self, trainer_state: TrainerState) -> bool: """Hook method called at the start of validation. 
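# A short sketch of the per-key reduction rules in the MetricsAggregator added
# above: "episode/" keys are averaged across observations, "async/" keys keep
# their last value, and keys in _SUM_KEYS are summed.
from rllm.experimental.metrics import MetricsAggregator

agg = MetricsAggregator()
agg.record("episode/num_turns", 2.0)
agg.record("episode/num_turns", 4.0)          # mean prefix -> 3.0
agg.record("async/buffer_qsize", 7.0)
agg.record("async/buffer_qsize", 3.0)         # last prefix -> 3.0
agg.record("groups/dropped_min_trajs", 1.0)
agg.record("groups/dropped_min_trajs", 2.0)   # explicit sum key -> 3.0
assert agg.flush() == {
    "episode/num_turns": 3.0,
    "async/buffer_qsize": 3.0,
    "groups/dropped_min_trajs": 3.0,
}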
diff --git a/rllm/experimental/rollout/__init__.py b/rllm/experimental/rollout/__init__.py index 6e5c4d681..50ab03477 100644 --- a/rllm/experimental/rollout/__init__.py +++ b/rllm/experimental/rollout/__init__.py @@ -9,11 +9,10 @@ __all__ = [ "ModelOutput", - # Rollout engines "RolloutEngine", "TinkerEngine", "VerlEngine", - # Token input/output types + # Token types "TokenInput", "TokenOutput", "TinkerTokenInput", @@ -30,7 +29,10 @@ def __getattr__(name): return _TinkerEngine if name == "VerlEngine": - from .verl_engine import VerlEngine as _VerlEngine + try: + from .verl_engine import VerlEngine as _VerlEngine - return _VerlEngine + return _VerlEngine + except Exception: + raise AttributeError(name) from None raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py index 7146be416..3837d59ec 100644 --- a/rllm/experimental/rollout/rollout_engine.py +++ b/rllm/experimental/rollout/rollout_engine.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from dataclasses import dataclass from typing import TYPE_CHECKING @@ -8,6 +9,8 @@ from rllm.parser import ChatTemplateParser from rllm.tools.tool_base import ToolCall +logger = logging.getLogger(__name__) + @dataclass class ModelOutput: @@ -20,9 +23,12 @@ class ModelOutput: multi_modal_inputs: dict[str, list] | None = None logprobs: list[float] | None = None # completion logprobs prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) prompt_length: int = 0 completion_length: int = 0 finish_reason: str | None = None + weight_version: int | None = None # policy version at time of generation + metrics: dict | None = None # per-turn server metrics (e.g. 
ttft, queue durations) def to_dict(self): return { @@ -38,6 +44,8 @@ def to_dict(self): "prompt_length": self.prompt_length, "completion_length": self.completion_length, "finish_reason": self.finish_reason, + "weight_version": self.weight_version, + "metrics": self.metrics, } @classmethod @@ -57,6 +65,8 @@ def from_dict(cls, data: dict): prompt_length=data.get("prompt_length", 0), completion_length=data.get("completion_length", 0), finish_reason=data.get("finish_reason"), + weight_version=data.get("weight_version"), + metrics=data.get("metrics"), ) @@ -66,10 +76,16 @@ class RolloutEngine: is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks def __init__(self, *args, **kwargs): - pass + self.weight_version: int = 0 + + # --- Model response --- + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + raise NotImplementedError(f"_get_model_response is not implemented for {self.__class__.__name__}") async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - raise NotImplementedError("get_model_response is not implemented") + result = await self._get_model_response(messages, **kwargs) + result.weight_version = self.weight_version + return result def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: """ @@ -81,13 +97,13 @@ async def get_token_output_from_token_input(self, token_input: TokenInput, **kwa """Obtain the token output from the given token input.""" raise NotImplementedError("get_token_output_from_token_input is not implemented") + @property + def supports_token_in_token_out(self) -> bool: + """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" + return False + async def wake_up(self): pass async def sleep(self): pass - - @property - def supports_token_in_token_out(self) -> bool: - """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" - return False diff --git a/rllm/experimental/rollout/tinker_engine.py b/rllm/experimental/rollout/tinker_engine.py index 77cdc24a2..0665777d2 100644 --- a/rllm/experimental/rollout/tinker_engine.py +++ b/rllm/experimental/rollout/tinker_engine.py @@ -186,6 +186,7 @@ def __init__( reasoning_effort: The effort level for reasoning (used when bypass_render_with_parser=True) renderer_name: The name of the renderer to use (used when bypass_render_with_parser=True) """ + super().__init__() self.base_url = base_url self.model_name = model_name self.max_prompt_length = max_prompt_length @@ -365,7 +366,7 @@ def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutp ) @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: """ Generate model response for a given set of messages. 
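# With the rollout_engine.py refactor above, get_model_response becomes a
# template method: subclasses implement _get_model_response, and the base class
# stamps its current weight_version onto every ModelOutput for async staleness
# tracking. A toy sketch of the contract; EchoEngine is hypothetical, and it
# assumes ModelOutput accepts content as a keyword argument with defaults for
# the remaining fields:
import asyncio

from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine

class EchoEngine(RolloutEngine):
    async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
        # Trivial "model": echo the last user message back.
        return ModelOutput(content=messages[-1]["content"])

async def demo() -> None:
    engine = EchoEngine()
    engine.weight_version = 3  # normally advanced by the trainer on parameter sync
    out = await engine.get_model_response([{"role": "user", "content": "hi"}])
    assert out.weight_version == 3  # stamped by the base-class wrapper

asyncio.run(demo())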
diff --git a/rllm/experimental/rollout/verl_engine.py b/rllm/experimental/rollout/verl_engine.py index 4540a8d1e..a9b5d99e3 100644 --- a/rllm/experimental/rollout/verl_engine.py +++ b/rllm/experimental/rollout/verl_engine.py @@ -24,6 +24,7 @@ def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokeni # reconstruct the servers list from the server_addresses and server_handles (Verl 0.7.0+) servers = zip(rollout_manager.server_addresses, rollout_manager.server_handles, strict=True) self.server_manager = AsyncLLMServerManager(config, servers=servers, load_balancer_handle=rollout_manager.global_load_balancer) + self.tokenizer = tokenizer self.processor = processor self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=config.get("rllm", {}).get("disable_thinking", False)) @@ -74,12 +75,13 @@ async def get_token_output_from_token_input(self, token_input: TokenInput, **kwa return token_output @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: # these go to the parser tools = kwargs.pop("tools", []) accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) + reasoning_effort = kwargs.pop("reasoning_effort", "medium") - prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning) + prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning, reasoning_effort=reasoning_effort) request_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) # list[int] if any(msg.get("images", None) is not None and msg["role"] == "user" for msg in messages) and self.processor is not None: diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py new file mode 100644 index 000000000..1405f3c0a --- /dev/null +++ b/rllm/experimental/sync_coordinator.py @@ -0,0 +1,129 @@ +"""SyncCoordinator: manages rollout quotas and parameter sync timing for fully-async training.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + + +@dataclass +class SyncCoordinatorConfig: + mini_batch_size: int # episode groups per optimizer step + group_size: int # episodes per group (rollout.n) + staleness_threshold: float + trigger_parameter_sync_step: int + + @property + def max_rollout_quota(self) -> int: + """Max dispatches per sync window (Verl/AReaL formulation).""" + return int((1 + self.staleness_threshold) * self.trigger_parameter_sync_step * self.mini_batch_size) + + +class SyncCoordinator: + """Coordinates rollout scheduling and parameter sync between generation and training loops. + + Uses a per-sync-window dispatch counter (matching Verl/AReaL). The counter + resets only on weight sync, not on consume. This guarantees zero staleness + when staleness_threshold=0. 
+ """ + + def __init__(self, config: SyncCoordinatorConfig): + self.config = config + + self._weight_version: int = 0 + self._quota_used: int = 0 # groups counting toward current sync window quota (includes carryover) + self._in_flight: int = 0 # groups dispatched but not yet consumed/filtered + self._steps_since_sync: int = 0 + self._total_syncs: int = 0 + + # Throttle — blocks generation when dispatched_since_sync >= max_rollout_quota + self._throttle_event: asyncio.Event = asyncio.Event() + self._throttle_event.set() + + # Generation pause — blocks generation during validation or weight sync + self._generation_paused: asyncio.Event = asyncio.Event() + self._generation_paused.set() + + # Tracks in-flight async rollout tasks for drain/wait logic + self._in_flight_tasks: set[asyncio.Task] = set() + + @property + def weight_version(self) -> int: + return self._weight_version + + # --- Throttle --- + + def on_group_dispatched(self) -> None: + """Generation loop dispatched one prompt (n rollouts).""" + self._quota_used += 1 + self._in_flight += 1 + if self._quota_used >= self.config.max_rollout_quota: + self._throttle_event.clear() + + def on_group_consumed(self) -> None: + """Training loop consumed one group from the buffer.""" + self._in_flight = max(0, self._in_flight - 1) + + def on_group_filtered(self) -> None: + """Accumulator filtered out a group. Decrements in-flight count.""" + self._in_flight = max(0, self._in_flight - 1) + + async def wait_for_throttle(self) -> None: + """Generation loop blocks here when dispatch window is full.""" + await self._throttle_event.wait() + + def has_quota(self) -> bool: + """Whether the generation loop can dispatch another group.""" + return self._quota_used < self.config.max_rollout_quota + + # --- Weight sync --- + + def on_training_step_complete(self) -> None: + self._steps_since_sync += 1 + + def should_sync(self) -> bool: + return self._steps_since_sync >= self.config.trigger_parameter_sync_step + + def on_sync_complete(self) -> None: + self._weight_version += 1 + self._steps_since_sync = 0 + self._total_syncs += 1 + # Reset dispatch window. In-flight items span the sync boundary — + # they were dispatched with old weights and count toward the new window. 
+ self._quota_used = self._in_flight + if self._quota_used < self.config.max_rollout_quota: + self._throttle_event.set() + + # --- Generation pause (for validation / weight sync if partial_rollout is False) --- + + def pause_generation(self) -> None: + self._generation_paused.clear() + + def resume_generation(self) -> None: + self._generation_paused.set() + + async def wait_for_generation_allowed(self) -> None: + await self._generation_paused.wait() + + # --- In-flight task tracking --- + + def track_task(self, task: asyncio.Task) -> None: + """Register an in-flight rollout task.""" + self._in_flight_tasks.add(task) + task.add_done_callback(self._in_flight_tasks.discard) + + async def wait_for_drain(self) -> None: + """Wait for all in-flight rollout tasks to complete.""" + while self._in_flight_tasks: + await asyncio.sleep(0.1) + + def stats(self) -> dict: + return { + "async/weight_version": self._weight_version, + "async/dispatched_since_sync": self._quota_used - self._in_flight, + "async/quota_used": self._quota_used, + "async/in_flight_groups": self._in_flight, + "async/steps_since_sync": self._steps_since_sync, + "async/max_rollout_quota": self.config.max_rollout_quota, + "async/total_syncs": self._total_syncs, + } diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index ea3a4868b..45819f7c7 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -1,5 +1,7 @@ import asyncio +import logging import time +import uuid from abc import ABC, abstractmethod from collections import Counter, defaultdict from collections.abc import Callable, Iterable @@ -9,14 +11,17 @@ import numpy as np from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset +from rllm.experimental.buffer import TrajectoryGroupBuffer from rllm.experimental.common.advantage import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, ) from rllm.experimental.common.config import ( + AsyncTrainingConfig, CompactFilteringConfig, RejectionSamplingConfig, TransformConfig, @@ -27,15 +32,22 @@ RejectionSamplingState, apply_rejection_sampling_and_filtering, ) -from rllm.experimental.common.transform import _default_traj_grouping_hook, transform_episodes_to_trajectory_groups -from rllm.experimental.common.visualization import visualize_trajectory_last_steps +from rllm.experimental.common.transform import ( + _default_traj_grouping_hook, + transform_episodes_to_trajectory_groups, +) +from rllm.experimental.common.visualization import print_metrics_table, visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine +from rllm.experimental.metrics import MetricsAggregator from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine +from rllm.experimental.sync_coordinator import SyncCoordinator, SyncCoordinatorConfig from rllm.utils import EpisodeLogger, Tracking, extract_source_metadata from rllm.workflows.store import Store from rllm.workflows.workflow import TerminationReason, Workflow +logger = logging.getLogger(__name__) + @dataclass class TrainerState: @@ -46,6 +58,7 @@ class TrainerState: epoch: int = 0 total_steps: int = 0 is_training: bool = True + weight_version: int = 0 # For timing and metrics timing_dict: dict = field(default_factory=dict) metrics: dict = field(default_factory=dict) @@ -129,11 +142,24 @@ def __init__( # Extract the 
TrajectoryGroup-specific estimator from kwargs self.traj_group_adv_estimator_map = traj_group_adv_estimator_map or {} + # TODO(kylemontgomery1): disaggregate UnifiedTrainer.__init__ from engine/infra setup + self.backend = backend_cls(config=config, **(backend_args or {})) self._validate_and_setup_configs() self._setup_logging() + # Async training config + async_cfg = self.rllm_config.get("async_training", {}) + self.async_config = AsyncTrainingConfig( + enable=async_cfg.get("enable", False), + mini_batch_size=async_cfg.get("mini_batch_size", 1), + fwd_bwd_group_size=async_cfg.get("fwd_bwd_group_size", 1), + staleness_threshold=async_cfg.get("staleness_threshold", 0.0), + trigger_parameter_sync_step=async_cfg.get("trigger_parameter_sync_step", 1), + partial_rollout=async_cfg.get("partial_rollout", True), + ) + rollout_engine: RolloutEngine = self.backend.init_rollout_engine( cf_config=self.cf_config, transform_config=self.transform_config, @@ -201,6 +227,7 @@ def __init__( runtime=self._remote_runtime, gateway=self._gateway, session_timeout=remote_runtime_config.session_timeout, + n_parallel_tasks=self.rllm_config.workflow.n_parallel_tasks, episode_logger=self.episode_logger, ) else: @@ -247,6 +274,7 @@ def _validate_and_setup_configs(self): mode=rs_mode, min_partial_solve_tasks=self.rllm_config.rejection_sample.min_partial_solve_tasks, min_trajs_per_group=self.rllm_config.rejection_sample.min_trajs_per_group, + filter_uniform_groups=self.rllm_config.rejection_sample.get("filter_uniform_groups", False), ) # algorithm config (used for rLLM-native advantage computation) @@ -290,6 +318,8 @@ def _setup_logging(self): # Main training loop methods # ========================================================================= + # TODO(kylemontgomery1): better separation of on-policy vs fully-async training code + def fit(self): """Main training loop (sync entry point).""" asyncio.run(self.fit_async()) @@ -310,8 +340,7 @@ async def fit_async(self) -> None: await self.backend.on_train_start(trainer_state) if self.rllm_config.trainer.get("val_before_train", True): - val_metrics = await self._validate_async(trainer_state) - pprint(f"Initial validation metrics: {val_metrics}") + await self._validate_async(trainer_state) if self.rllm_config.trainer.get("val_only", False): return @@ -324,7 +353,16 @@ async def fit_async(self) -> None: await self.backend.on_train_end(trainer_state) async def _fit_async(self, trainer_state: TrainerState) -> None: - """Internal async main training loop.""" + """Dispatch to sync or concurrent training based on config.""" + # TODO(listar2000): after some benchmarking, maybe we just keep the fully-async and treat on-policy as a special case.
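+        # The async_training.enable flag read from rllm_config above selects the
+        # fully-async path; it defaults to false, which keeps the on-policy loop.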
+ if self.async_config.enable: + await self._fit_fully_async(trainer_state) + else: + await self._fit_on_policy(trainer_state) + + async def _fit_on_policy(self, trainer_state: TrainerState) -> None: + """Synchronous on-policy training loop (the simplest standalone case; supports neither mini-batching nor off-policy training).""" + # TODO(kylemontgomery1): dataloader should be backend-agnostic train_dataloader: Iterable = self.backend.get_dataloader(self.train_dataset, trainer_state) break_via_total_batches = False # used to break the training loop via the `total_batches` parameter use_total_batches = self.rllm_config.trainer.get("total_batches") is not None and self.rllm_config.trainer.total_batches > 0 @@ -351,6 +389,7 @@ async def _fit_async(self, trainer_state: TrainerState) -> None: await self._train_batch_async(batch, trainer_state) await self.backend.on_batch_end(trainer_state) + print_metrics_table(trainer_state.metrics, trainer_state.global_step) self.logger.log( data=trainer_state.metrics, step=trainer_state.global_step, @@ -373,13 +412,13 @@ async def _fit_async(self, trainer_state: TrainerState) -> None: # final validation after training if self.rllm_config.trainer.test_freq > 0: - val_metrics = await self._validate_async(trainer_state) - pprint(f"Final validation metrics: {val_metrics}") + await self._validate_async(trainer_state) async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> None: """Train a batch (async implementation).""" self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=trainer_state.epoch) + # TODO(kylemontgomery1): episode generation should be backend-agnostic # stage 1: generate episodes (async) and collect metrics (sync) trainer_state.episodes = await self.backend.generate_episodes(batch, agent_workflow_engine=self.agent_workflow_engine, is_validation=False) if not trainer_state.has_episodes: @@ -413,6 +452,7 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N await self.backend.process_backend_batch(trainer_state) assert trainer_state.has_backend_batch, "Backend batch is not transformed or processed successfully" + # TODO(kylemontgomery1): compute advantages should be backend-agnostic # stage 6: compute advantages (async) await self.backend.compute_advantages(trainer_state, self.algorithm_config) @@ -435,6 +475,262 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N for r in TerminationReason: trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts + # ========================================================================= + # Fully-asynchronous training pipeline + # ========================================================================= + + async def _fit_fully_async(self, trainer_state: TrainerState) -> None: + """Fully-async generation + training with group-level streaming.""" + assert self.config.data.train_batch_size == 1, f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" + assert not getattr(self.agent_workflow_engine, "raise_on_error", False), "Async training requires raise_on_error=False so that process_task_with_retry always returns an episode" + coord_config = SyncCoordinatorConfig( + mini_batch_size=self.async_config.mini_batch_size, + group_size=self.rllm_config.rollout.n, + staleness_threshold=self.async_config.staleness_threshold, + trigger_parameter_sync_step=self.async_config.trigger_parameter_sync_step, + ) + coordinator =
SyncCoordinator(coord_config) + aggregator = MetricsAggregator() + buffer = TrajectoryGroupBuffer( + group_size=self.rllm_config.rollout.n, + coordinator=coordinator, + aggregator=aggregator, + algorithm_config=self.algorithm_config, + transform_config=self.transform_config, + cf_config=self.cf_config, + rs_config=self.rs_config, + episode_offload_dir=self.async_config.episode_offload_dir, + trajectory_group_offload_dir=self.async_config.trajectory_group_offload_dir, + ) + + # Compute total_steps for LR scheduling + train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + if use_total_batches: + trainer_state.total_steps = self.rllm_config.trainer.total_batches + else: + trainer_state.total_steps = len(train_dataloader) * self.rllm_config.trainer.total_epochs + + total_tasks = len(train_dataloader) * self.rllm_config.trainer.total_epochs + pbar = tqdm(total=total_tasks, desc="Tasks", unit="task") + buffer._pbar = pbar + + try: + gen_task = asyncio.create_task(self._generation_loop(trainer_state, buffer, coordinator)) + await self._training_loop(trainer_state, buffer, coordinator, aggregator) + if not gen_task.done(): + gen_task.cancel() + try: + await gen_task + except asyncio.CancelledError: + pass + finally: + pbar.close() + + async def _generation_loop( + self, + trainer_state: TrainerState, + buffer: TrajectoryGroupBuffer, + coordinator: SyncCoordinator, + ) -> None: + """Generate episodes and stream to TrajectoryGroupBuffer.""" + group_size = self.rllm_config.rollout.n + + try: + for epoch in range(self.rllm_config.trainer.total_epochs): + await self.backend.on_epoch_start(trainer_state) + train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) + self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) + + for batch in train_dataloader: + task = batch[0] + + await coordinator.wait_for_generation_allowed() + if not coordinator.has_quota(): + await coordinator.wait_for_throttle() + coordinator.on_group_dispatched() + + task_id = str(uuid.uuid4()) + for rollout_idx in range(group_size): + + async def _run_rollout(t=task, tid=task_id, ridx=rollout_idx): + _, _, _, episode = await self.agent_workflow_engine.process_task_with_retry(task=t, task_id=tid, rollout_idx=ridx, result_idx=0) + await buffer.add_episode(tid, episode) + + rollout_task = asyncio.create_task(_run_rollout()) + coordinator.track_task(rollout_task) + + await self.backend.on_epoch_end(trainer_state) + + await coordinator.wait_for_drain() + finally: + buffer.mark_generation_complete() + + async def _training_loop( + self, + trainer_state: TrainerState, + buffer: TrajectoryGroupBuffer, + coordinator: SyncCoordinator, + aggregator: MetricsAggregator, + ) -> None: + """Consume task batches from buffer, run forward-backward + optimizer step.""" + mini_batch_size = self.async_config.mini_batch_size + fwd_bwd_group_size = self.async_config.fwd_bwd_group_size + num_fwd_bwd_passes = mini_batch_size // fwd_bwd_group_size + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + rollout_engine = getattr(self.agent_workflow_engine, "rollout_engine", None) + + while True: + trainer_state.reset_batch() + step_start = time.perf_counter() + weight_versions = [] + all_trajectory_groups: list[TrajectoryGroup] = [] + all_episodes: list[Episode] = [] + groups_consumed = 0 + buffer_wait_time = 0.0 + done = False + + buffered = buffer._queue.qsize() + logger.info(
f"[TrainingLoop] Step {trainer_state.global_step}: waiting for {mini_batch_size} task batches ({num_fwd_bwd_passes} fwd-bwd passes x {fwd_bwd_group_size} groups), {buffered} buffered" + ) + + # 1. Pull mini_batch_size task batches total, split into + # num_fwd_bwd_passes forward-backward passes of fwd_bwd_group_size each. + for pass_idx in range(num_fwd_bwd_passes): + chunk_groups: list[TrajectoryGroup] = [] + + for _ in range(fwd_bwd_group_size): + t_wait = time.perf_counter() + task_batch = await buffer.get() + buffer_wait_time += time.perf_counter() - t_wait + if task_batch is None: + done = True + break + + coordinator.on_group_consumed() + groups_consumed += 1 + + for group in task_batch.groups: + weight_versions.append(group.weight_version) + chunk_groups.extend(task_batch.groups) + all_trajectory_groups.extend(task_batch.groups) + all_episodes.extend(task_batch.episodes) + + if not chunk_groups or done: + break + + # Forward-backward on this chunk + trainer_state.trajectory_groups = chunk_groups + + if trainer_state.has_trajectory_groups: + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: fwd-bwd pass {pass_idx + 1}/{num_fwd_bwd_passes} ({len(chunk_groups)} groups)") + await self.backend.on_batch_start(trainer_state) + trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) + await self.backend.process_backend_batch(trainer_state) + + # Drain per-chunk backend metrics into aggregator + aggregator.record_dict(trainer_state.metrics) + trainer_state.metrics = {} + + # Only run optimizer step on a full batch + if groups_consumed < mini_batch_size: + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: incomplete batch ({groups_consumed}/{mini_batch_size}), stopping") + break + + # 2. Optimizer step + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: optimizer step") + await self.backend.update_policy(trainer_state) + + # 3. Capture pre-sync metrics (before weight sync resets coordinator state) + staleness_values = [coordinator.weight_version - v for v in weight_versions] + aggregator.record("async/staleness_mean", float(np.mean(staleness_values))) + aggregator.record("async/staleness_min", float(np.min(staleness_values))) + aggregator.record("async/staleness_max", float(np.max(staleness_values))) + aggregator.record("async/groups_consumed", groups_consumed) + aggregator.record("time/buffer_wait", buffer_wait_time) + pre_sync_coordinator_stats = coordinator.stats() + pre_sync_buffer_stats = buffer.stats() + + # 4. Weight sync + coordinator.on_training_step_complete() + sync_time = 0.0 + if coordinator.should_sync(): + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: triggering weight sync") + t0 = time.perf_counter() + await self._perform_weight_sync(trainer_state, coordinator, rollout_engine) + sync_time = time.perf_counter() - t0 + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: weight sync complete ({sync_time:.2f}s)") + if sync_time > 0: + aggregator.record("time/weight_sync", sync_time) + aggregator.record("time/step", time.perf_counter() - step_start) + + # Set all trajectory groups and stripped episodes for visualization/logging + trainer_state.trajectory_groups = all_trajectory_groups + trainer_state.episodes = all_episodes + + if self.tokenizer is not None and trainer_state.has_trajectory_groups: + visualize_trajectory_last_steps( + trainer_state.trajectory_groups, + tokenizer=self.tokenizer, + max_steps_to_visualize=2, + show_workflow_metadata=True, + ) + + # 5. 
Flush aggregator and merge pre-sync snapshots into trainer_state.metrics + trainer_state.metrics.update(aggregator.flush()) + trainer_state.metrics.update(pre_sync_buffer_stats) + trainer_state.metrics.update(pre_sync_coordinator_stats) + + # 6. Compute derived metrics + step_time = trainer_state.metrics.get("time/step", 1.0) + trainer_state.metrics["async/trainer_idle_ratio"] = buffer_wait_time / max(step_time, 1e-9) + + # 7. on_batch_end writes backend metrics (progress, optim, timing) + await self.backend.on_batch_end(trainer_state) + + # 8. Print and log + print_metrics_table(trainer_state.metrics, trainer_state.global_step) + self.logger.log( + data=trainer_state.metrics, + step=trainer_state.global_step, + episodes=trainer_state.episodes, + trajectory_groups=trainer_state.trajectory_groups, + ) + + # Periodic validation + if self.rllm_config.trainer.test_freq > 0 and trainer_state.global_step % self.rllm_config.trainer.test_freq == 0: + await self._validate_async_with_pause(trainer_state, coordinator) + + trainer_state.global_step += 1 + + if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: + break + + async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: SyncCoordinator, rollout_engine: RolloutEngine | None) -> None: + """Synchronize weights between training and rollout engines.""" + if not self.async_config.partial_rollout: + coordinator.pause_generation() + await coordinator.wait_for_drain() + + trainer_state.weight_version = coordinator.weight_version + 1 + await self.backend.on_policy_updated(trainer_state) + if rollout_engine is not None: + rollout_engine.weight_version = trainer_state.weight_version + coordinator.on_sync_complete() + + if not self.async_config.partial_rollout: + coordinator.resume_generation() + + async def _validate_async_with_pause(self, trainer_state: TrainerState, coordinator: SyncCoordinator) -> dict: + """Validation with dispatch-level pause.
Waits for workflows to drain, then runs validation.""" + coordinator.pause_generation() + await coordinator.wait_for_drain() + try: + return await self._validate_async(trainer_state) + finally: + coordinator.resume_generation() + async def _validate_async(self, trainer_state: TrainerState) -> dict: """Validate the model (async implementation).""" n_val_samples = self.rllm_config.rollout.n_val @@ -453,7 +749,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: for batch in val_dataloader: # Generate episodes and transform to trajectory groups val_episodes = await self.backend.generate_episodes(batch, agent_workflow_engine=self.agent_workflow_engine, is_validation=True) - val_trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups(val_episodes, self.transform_config, self.cf_config, traj_grouping_hook=self.traj_grouping_hook) + val_trajectory_groups, _ = transform_episodes_to_trajectory_groups(val_episodes, self.transform_config, self.cf_config, traj_grouping_hook=self.traj_grouping_hook) reward_metrics = collect_reward_and_advantage_from_trajectory_groups(val_trajectory_groups, self.algorithm_config, collect_advantage=False) is_correct_lst.extend([episode.is_correct for episode in val_episodes]) @@ -466,7 +762,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: for key, value in episode.metrics.items(): workflow_metrics_by_source[data_source][key].append(float(value)) - for key, value in (transform_metrics | reward_metrics).items(): + for key, value in reward_metrics.items(): val_metrics[f"val/{key}"].append(value) test_end = time.perf_counter() @@ -496,6 +792,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: # post-process the val metrics to reduce any "list values" into scalars reduce_metrics_lists(val_metrics) + print_metrics_table(val_metrics, trainer_state.global_step, title="Validation") self.logger.log(data=val_metrics, step=trainer_state.global_step) await self.backend.on_validation_end(trainer_state) return val_metrics diff --git a/rllm/experimental/verl/verl_backend.py b/rllm/experimental/verl/verl_backend.py index 6473bbca7..06bbd49e7 100644 --- a/rllm/experimental/verl/verl_backend.py +++ b/rllm/experimental/verl/verl_backend.py @@ -622,10 +622,10 @@ async def on_validation_start(self, trainer_state: TrainerState) -> bool: return False else: trainer_state.is_training = False - self.rollout_engine.validate = True # type: ignore[attr-defined] + self.rollout_engine.is_validation = True return True async def on_validation_end(self, trainer_state: TrainerState) -> None: """Called at the end of validation.""" trainer_state.is_training = True - self.rollout_engine.validate = False # type: ignore[attr-defined] + self.rollout_engine.is_validation = False diff --git a/rllm/rewards/countdown_reward.py b/rllm/rewards/countdown_reward.py index ccdf6157a..093d35f95 100644 --- a/rllm/rewards/countdown_reward.py +++ b/rllm/rewards/countdown_reward.py @@ -1,9 +1,12 @@ +import logging import random import re from rllm import Action from rllm.rewards.reward_types import RewardOutput +logger = logging.getLogger(__name__) + def extract_solution(solution_str): """Extract the equation from the solution string.""" @@ -72,20 +75,17 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.1, do_print = random.randint(1, 64) == 1 if do_print: - print("--------------------------------") - print(f"Target: {target} | Numbers: {numbers}") - print(f"Extracted equation: {equation}") - 
print(f"Solution string: {solution_str}") + logger.debug(f"Target: {target} | Numbers: {numbers} | Equation: {equation} | Solution: {solution_str}") if equation is None: if do_print: - print("No equation found") + logger.debug("No equation found") return 0 # Validate equation uses correct numbers if not validate_equation(equation, numbers): if do_print: - print("Invalid equation") + logger.debug("Invalid equation") return format_score # Evaluate equation @@ -93,20 +93,20 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.1, result = evaluate_equation(equation) if result is None: if do_print: - print("Could not evaluate equation") + logger.debug("Could not evaluate equation") return format_score if abs(result - target) < 1e-5: # Account for floating point precision if do_print: - print(f"Correct equation: {equation} = {result}") + logger.debug(f"Correct equation: {equation} = {result}") return score else: if do_print: - print(f"Wrong result: equation = {result}, target = {target}") + logger.debug(f"Wrong result: equation = {result}, target = {target}") return format_score except Exception: if do_print: - print("Error evaluating equation") + logger.debug("Error evaluating equation") return format_score diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 29ea2793e..00ffd46ef 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -27,7 +27,6 @@ from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine, TinkerEngine from rllm.trainer.tinker.tinker_metrics_utils import ( - print_metrics_table, update_training_metrics, ) from rllm.trainer.tinker.tinker_policy_trainer import TinkerPolicyTrainer @@ -98,6 +97,9 @@ def __init__( # Store algorithm config for use in process_backend_batch self._algorithm_config: AlgorithmConfig | None = None + # Track whether on_policy_updated was called this step (for backward compat) + self._policy_updated_this_step: bool = False + # Specific optimizer parameters for Tinker self.learning_rate = self.full_config.training.get("learning_rate", 1e-6) self.beta1 = self.full_config.training.get("beta1", 0.9) @@ -189,6 +191,8 @@ def get_dataloader(self, dataset: Dataset | None, trainer_state: TrainerState) - shuffle = True else: batch_size = self.full_config.data.get("val_batch_size", self.full_config.data.train_batch_size) + if batch_size == -1: + batch_size = len(dataset) shuffle = False return torch.utils.data.DataLoader( @@ -404,6 +408,9 @@ async def on_train_start(self, trainer_state: TrainerState) -> None: resume = bool(self.full_config.training.resume_from_tinker_id) start_batch, self.sampling_client = await self.policy_trainer.initialize_async(resume_from_checkpoint=resume) + # Propagate sampling_client to rollout engine so it can make inference calls + self.rollout_engine.set_sampling_client(self.sampling_client) + # Update trainer state with the start batch from checkpoint trainer_state.global_step = start_batch @@ -416,29 +423,38 @@ async def on_train_end(self, trainer_state: TrainerState) -> None: logger.info(f"Saving final checkpoint at step {trainer_state.global_step}") await self.policy_trainer.save_checkpoint_and_get_sampling_client(trainer_state.global_step, kind="both", do_save=True) + async def on_policy_updated(self, trainer_state: TrainerState) -> None: + """Save checkpoint and update sampling_client after policy update.""" + assert self.policy_trainer is not None, "policy_trainer is not 
initialized" + self._policy_updated_this_step = True + + global_step = trainer_state.global_step + save_freq = self.full_config.rllm.trainer.save_freq + do_save = save_freq > 0 and global_step % save_freq == 0 + self.sampling_client = await self.policy_trainer.save_checkpoint_and_get_sampling_client(global_step, kind="both", do_save=do_save) + + # Propagate updated sampling_client to rollout engine for async weight sync + self.rollout_engine.set_sampling_client(self.sampling_client) + async def on_batch_end(self, trainer_state: TrainerState) -> None: """Called at the end of each batch. - Saves checkpoint, updates sampling client, and prints metrics. + In sync mode, on_policy_updated() is not called separately, so we + do the checkpoint/sampling_client update here for backward compat. """ assert self.policy_trainer is not None, "policy_trainer is not initialized" - global_step = trainer_state.global_step - # Save sampler checkpoint after each batch - with simple_timer("save_checkpoint", trainer_state.timing_dict): - logger.info(f"Saving state checkpoint and sampler at step {global_step}") - save_freq = self.full_config.rllm.trainer.save_freq - do_save = save_freq > 0 and global_step % save_freq == 0 - self.sampling_client = await self.policy_trainer.save_checkpoint_and_get_sampling_client(global_step, kind="both", do_save=do_save) + # If on_policy_updated() wasn't called (sync mode), do checkpoint here + if not self._policy_updated_this_step: + with simple_timer("save_checkpoint", trainer_state.timing_dict): + logger.info(f"Saving state checkpoint and sampler at step {trainer_state.global_step}") + await self.on_policy_updated(trainer_state) + self._policy_updated_this_step = False # Update metrics learning_rate = trainer_state.extra_info.get("scheduled_learning_rate", self.learning_rate) update_training_metrics(trainer_state, learning_rate, trainer_state.total_steps) - # Print metrics table - if trainer_state.metrics: - print_metrics_table(trainer_state.metrics, global_step) - async def on_epoch_start(self, trainer_state: TrainerState) -> None: """Called at the start of an epoch.""" logger.info(f"Starting epoch {trainer_state.epoch}") diff --git a/rllm/trainer/tinker/tinker_metrics_utils.py b/rllm/trainer/tinker/tinker_metrics_utils.py index 3976081f3..618a6ce30 100644 --- a/rllm/trainer/tinker/tinker_metrics_utils.py +++ b/rllm/trainer/tinker/tinker_metrics_utils.py @@ -5,61 +5,12 @@ import tinker import torch +from rllm.experimental.common.visualization import print_metrics_table # noqa: F401 (re-export) from rllm.experimental.unified_trainer import TrainerState logger = logging.getLogger(__name__) -def print_metrics_table(metrics: dict, step: int): - """ - Print metrics as a formatted table (similar to tinker_cookbook). 
- - Args: - metrics: Dictionary of metrics - step: Current step number - """ - try: - from rich.console import Console - from rich.table import Table - - console = Console() - - # Create table - table = Table(title=f"Step {step}", show_header=True, header_style="bold magenta") - table.add_column("Metric", style="cyan", no_wrap=False) - table.add_column("Value", justify="right", style="green") - - # Sort metrics by key for consistent ordering - sorted_metrics = sorted(metrics.items()) - - for key, value in sorted_metrics: - # Format value based on type - if isinstance(value, float): - value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" - elif isinstance(value, int): - value_str = str(value) - else: - value_str = str(value) - - table.add_row(key, value_str) - - console.print(table) - - except ImportError: - # Fallback to simple text table if rich is not available - print(f"\nStep {step}") - print("=" * 60) - for key, value in sorted(metrics.items()): - if isinstance(value, float): - value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" - elif isinstance(value, int): - value_str = str(value) - else: - value_str = str(value) - print(f"{key:40s} {value_str:>15s}") - print("=" * 60) - - def compute_kl_and_entropy_metrics(training_datums: list[tinker.Datum], training_logprobs: list[torch.Tensor]) -> dict: """ Compute KL divergence and entropy metrics from training. @@ -102,10 +53,10 @@ def compute_kl_and_entropy_metrics(training_datums: list[tinker.Datum], training perplexity = torch.exp(torch.tensor(entropy_sample)).item() return { - "optim/kl_sample_train_v1": kl_sample_train_v1, - "optim/kl_sample_train_v2": kl_sample_train_v2, - "optim/entropy": entropy_sample, - "optim/perplexity": perplexity, + "train/kl_sample_train_v1": kl_sample_train_v1, + "train/kl_sample_train_v2": kl_sample_train_v2, + "train/entropy": entropy_sample, + "train/perplexity": perplexity, } @@ -125,7 +76,7 @@ def update_training_metrics(trainer_state: TrainerState, learning_rate: float, t { "progress/batch": trainer_state.global_step, "progress/epoch": trainer_state.epoch, - "optim/lr": learning_rate, + "progress/lr": learning_rate, } ) diff --git a/rllm/trainer/tinker/tinker_policy_trainer.py b/rllm/trainer/tinker/tinker_policy_trainer.py index 57d7bd6ca..30f1caa6f 100644 --- a/rllm/trainer/tinker/tinker_policy_trainer.py +++ b/rllm/trainer/tinker/tinker_policy_trainer.py @@ -255,12 +255,18 @@ async def forward_backward_from_trajectory_groups( # Wait for completion and extract logprobs fwd_bwd_results = await asyncio.gather(*fwd_bwd_futures) - # Extract training logprobs from loss_fn_outputs + # Extract training logprobs and server-side metrics from results training_logprobs = [] for fwd_bwd_result in fwd_bwd_results: for output in fwd_bwd_result.loss_fn_outputs: logprobs = output["logprobs"].to_torch() training_logprobs.append(logprobs) + # Capture server-side metrics (e.g. 
loss) under train/ prefix + if fwd_bwd_result.metrics: + for k, v in fwd_bwd_result.metrics.items(): + if k.startswith("clock_cycle"): + continue + adv_metrics[f"train/{k.replace(':', '/')}"] = v return training_datums, training_logprobs, adv_metrics @@ -335,6 +341,11 @@ async def fused_forward_backward_and_optim_step( for output in fwd_bwd_result.loss_fn_outputs: logprobs = output["logprobs"].to_torch() training_logprobs.append(logprobs) + if fwd_bwd_result.metrics: + for k, v in fwd_bwd_result.metrics.items(): + if k.startswith("clock_cycle"): + continue + adv_metrics[f"train/{k.replace(':', '/')}"] = v return training_datums, training_logprobs, adv_metrics, scheduled_learning_rate diff --git a/rllm/trainer/tinker/transform.py b/rllm/trainer/tinker/transform.py index f758b7ce1..13497ab63 100644 --- a/rllm/trainer/tinker/transform.py +++ b/rllm/trainer/tinker/transform.py @@ -36,7 +36,7 @@ def _flatten_token_input(token_input: TinkerTokenInput) -> TinkerTokenInput: return flattened -def trajectory_to_datums(traj: Trajectory) -> list[tinker.Datum]: +def trajectory_to_datums(traj: Trajectory, router_replay: bool = False) -> list[tinker.Datum]: """ Return one or more Datum objects corresponding to the trajectory. If the sequence grows by appending, i.e., each successive observation contains @@ -61,6 +61,7 @@ class SequenceAccumulator: sampled_logprobs: list[float] = [] advantages: list[float] = [] mask: list[float] = [] + routing_matrices: list[str] = [] @classmethod def clear(cls): @@ -68,6 +69,7 @@ def clear(cls): cls.sampled_logprobs = [] cls.advantages = [] cls.mask = [] + cls.routing_matrices = [] def make_datum_from_state(): all_tokens_T = _flat_token_input_to_model_input(SequenceAccumulator.full_sequence) @@ -77,6 +79,9 @@ def make_datum_from_state(): advantages_T = SequenceAccumulator.advantages[1:] mask_T = SequenceAccumulator.mask[1:] assert input_tokens_T.length == len(target_tokens_T) == len(sampled_logprobs_T) == len(advantages_T) == len(mask_T) + if router_replay and SequenceAccumulator.routing_matrices: + rm_shifted = SequenceAccumulator.routing_matrices[1:] # match rightshift + input_tokens_T = input_tokens_T.model_copy(update={"routing_matrices": rm_shifted}) return tinker.Datum( model_input=input_tokens_T, loss_fn_inputs={ @@ -118,6 +123,9 @@ def make_datum_from_state(): SequenceAccumulator.sampled_logprobs.extend([0.0] * delta_token_input_length + output_logprobs) SequenceAccumulator.advantages.extend([0] * delta_token_input_length + advantages) SequenceAccumulator.mask.extend([0.0] * delta_token_input_length + [1.0] * len(output_token_ids)) + if router_replay: + step_rm = step.routing_matrices or [] + SequenceAccumulator.routing_matrices.extend([""] * delta_token_input_length + (list(step_rm) if step_rm else [""] * len(output_token_ids))) if SequenceAccumulator.full_sequence: data.append(make_datum_from_state()) @@ -137,9 +145,12 @@ def transform_trajectory_groups_to_datums( If the `estimator_map` is used in the algorithm config, we return a dictionary of datums, keyed by the trajectory group role. Otherwise, we return a list of datums. 
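    Example: with an estimator_map, the return value is keyed by group_role, e.g.
    {"some_role": [datum, ...]}; without one, a flat [datum, ...] list is returned.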
""" - # step 1: compute the advantages for each group using the common functionality - # this fills the `advantage` attribute of all the steps in the trajectory groups - adv_metrics = collect_reward_and_advantage_from_trajectory_groups(trajectory_groups, algorithm_config) + # step 1: compute advantages (skip if already pre-computed by buffer) + has_advantages = any(step.advantage is not None for group in trajectory_groups for traj in group.trajectories for step in traj.steps) + if has_advantages: + adv_metrics = {} + else: + adv_metrics = collect_reward_and_advantage_from_trajectory_groups(trajectory_groups, algorithm_config) if algorithm_config.estimator_map: datums_dict = defaultdict(list) @@ -147,11 +158,27 @@ def transform_trajectory_groups_to_datums( datums = [] # step 2: iterate over all steps and build the Tinker Datum objects + seqs_per_traj = [] + seq_lengths = [] for group in trajectory_groups: for trajectory in group.trajectories: + traj_datums = trajectory_to_datums(trajectory, router_replay=algorithm_config.router_replay) + seqs_per_traj.append(len(traj_datums)) + for d in traj_datums: + seq_lengths.append(d.model_input.length) if algorithm_config.estimator_map: - datums_dict[group.group_role].extend(trajectory_to_datums(trajectory)) + datums_dict[group.group_role].extend(traj_datums) else: - datums.extend(trajectory_to_datums(trajectory)) + datums.extend(traj_datums) + + if seqs_per_traj: + import numpy as _np + + adv_metrics["batch/seqs_per_traj/mean"] = _np.mean(seqs_per_traj) + adv_metrics["batch/seqs_per_traj/min"] = _np.min(seqs_per_traj) + adv_metrics["batch/seqs_per_traj/max"] = _np.max(seqs_per_traj) + adv_metrics["batch/seq_length/mean"] = _np.mean(seq_lengths) + adv_metrics["batch/seq_length/min"] = _np.min(seq_lengths) + adv_metrics["batch/seq_length/max"] = _np.max(seq_lengths) return (datums if not algorithm_config.estimator_map else datums_dict), adv_metrics diff --git a/rllm/workflows/workflow.py b/rllm/workflows/workflow.py index d427c5648..dd4f092b9 100644 --- a/rllm/workflows/workflow.py +++ b/rllm/workflows/workflow.py @@ -197,7 +197,7 @@ def assign_episode_correctness(self, episode: Episode) -> None: """ total_reward = 0 for trajectory in episode.trajectories: - total_reward += trajectory.reward + total_reward += trajectory.reward or 0 episode.is_correct = total_reward > 0 def collect_metrics(self, episode: Episode) -> None: @@ -210,7 +210,7 @@ def collect_metrics(self, episode: Episode) -> None: metrics = defaultdict(list) for traj in episode.trajectories: name = traj.name - metrics[name].append(traj.reward) + metrics[name].append(traj.reward or 0.0) episode.metrics = {f"{k}_acc": float(np.mean(v)) for k, v in metrics.items()} def postprocess_episode(self, episode: Episode, termination_reason: TerminationReason = None, error: dict = None) -> Episode: