From f46aa5462102060b177b0bd49d8c9b8527001884 Mon Sep 17 00:00:00 2001 From: listar2000 <35262801+listar2000@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:57:12 -0600 Subject: [PATCH 01/21] init new feature on unified fully async design --- rllm/experimental/async_config.py | 26 ++++ rllm/experimental/config/rllm/base.yaml | 7 + rllm/experimental/experience_buffer.py | 111 +++++++++++++++ rllm/experimental/protocol.py | 9 ++ rllm/experimental/unified_trainer.py | 177 +++++++++++++++++++++++- rllm/trainer/tinker/tinker_backend.py | 31 +++-- 6 files changed, 351 insertions(+), 10 deletions(-) create mode 100644 rllm/experimental/async_config.py create mode 100644 rllm/experimental/experience_buffer.py diff --git a/rllm/experimental/async_config.py b/rllm/experimental/async_config.py new file mode 100644 index 000000000..df4a46c38 --- /dev/null +++ b/rllm/experimental/async_config.py @@ -0,0 +1,26 @@ +"""Configuration for async (concurrent generation + training) mode.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class AsyncTrainingConfig: + """Controls the async training behavior spectrum. + + When `enabled` is False, the trainer uses the current synchronous pipeline. + When `enabled` is True, the trainer runs concurrent generation + training + with staleness-based filtering. + + Behavior spectrum: + - max_staleness=0, buffer_size=1: Effectively synchronous (backpressure serializes) + - max_staleness=1, buffer_size=2: 1-step overlap + - max_staleness=k, buffer_size=k+1: k-step off-policy + - max_staleness=5, buffer_size=8: Fully async with aggressive filtering + """ + + enabled: bool = False + max_staleness: int = 0 + buffer_size: int = 1 + requeue_stale: bool = True diff --git a/rllm/experimental/config/rllm/base.yaml b/rllm/experimental/config/rllm/base.yaml index 66e2b073e..07eb4c459 100644 --- a/rllm/experimental/config/rllm/base.yaml +++ b/rllm/experimental/config/rllm/base.yaml @@ -107,6 +107,13 @@ sdk: mode: subprocess # external | subprocess admin_token: my-shared-secret +# Async Training Configuration +async_training: + enabled: false # When false, use current synchronous pipeline + max_staleness: 0 # Max policy version gap (0 = effectively synchronous) + buffer_size: 1 # Max buffered experiences + requeue_stale: true # Requeue stale batches for regeneration + # Episode Logging Configuration episode_logging: log_episodes: false diff --git a/rllm/experimental/experience_buffer.py b/rllm/experimental/experience_buffer.py new file mode 100644 index 000000000..cc58c1628 --- /dev/null +++ b/rllm/experimental/experience_buffer.py @@ -0,0 +1,111 @@ +"""Experience buffer protocol and asyncio implementation for async training.""" + +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +from rllm.agents.agent import Episode, TrajectoryGroup + + +@dataclass +class BufferedExperience: + """A unit of experience stored in the buffer.""" + + trajectory_groups: list[TrajectoryGroup] + episodes: list[Episode] + policy_version: int + batch_source: Any # Original batch (for requeuing) + + +class ExperienceBufferProtocol(ABC): + """Abstract base class for experience buffers. + + Different backends can provide different implementations: + - AsyncioExperienceBuffer: Single-threaded asyncio.Queue for Tinker + - RayExperienceBuffer (future): Ray actor for multi-process Verl + """ + + @abstractmethod + async def put(self, experience: BufferedExperience) -> None: + """Add experience. Blocks if buffer full (backpressure).""" + + @abstractmethod + async def get(self, current_policy_version: int) -> BufferedExperience | None: + """Get next non-stale experience. Returns None when done.""" + + @abstractmethod + def mark_generation_complete(self) -> None: + """Signal that generation is finished.""" + + @abstractmethod + async def get_requeue_batch(self) -> Any | None: + """Get a stale batch to regenerate, or None.""" + + @abstractmethod + def stats(self) -> dict: + """Buffer statistics for metrics.""" + + +class AsyncioExperienceBuffer(ExperienceBufferProtocol): + """Single-threaded asyncio-based buffer for Tinker backend. + + Uses asyncio.Queue for backpressure. Tinker's compute happens on remote + servers, so the Python process only orchestrates — no threading needed. + """ + + def __init__(self, max_size: int, max_staleness: int, requeue_stale: bool): + self._queue: asyncio.Queue[BufferedExperience | None] = asyncio.Queue(maxsize=max_size) + self._requeue_queue: asyncio.Queue = asyncio.Queue() + self._max_staleness = max_staleness + self._requeue_stale = requeue_stale + self._generation_complete = False + # Stats + self._total_produced = 0 + self._total_consumed = 0 + self._total_discarded = 0 + + async def put(self, experience: BufferedExperience) -> None: + await self._queue.put(experience) + self._total_produced += 1 + + async def get(self, current_policy_version: int) -> BufferedExperience | None: + while True: + if self._generation_complete and self._queue.empty(): + return None + try: + experience = await asyncio.wait_for(self._queue.get(), timeout=1.0) + except asyncio.TimeoutError: + if self._generation_complete and self._queue.empty(): + return None + continue + if experience is None: # sentinel + return None + # Staleness check + version_gap = current_policy_version - experience.policy_version + if version_gap > self._max_staleness: + self._total_discarded += 1 + if self._requeue_stale: + await self._requeue_queue.put(experience.batch_source) + continue + self._total_consumed += 1 + return experience + + def mark_generation_complete(self) -> None: + self._generation_complete = True + + async def get_requeue_batch(self) -> Any | None: + try: + return self._requeue_queue.get_nowait() + except asyncio.QueueEmpty: + return None + + def stats(self) -> dict: + return { + "async/buffer_size": self._queue.qsize(), + "async/total_produced": self._total_produced, + "async/total_consumed": self._total_consumed, + "async/total_discarded": self._total_discarded, + } diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index 0c8491dbe..1e1027705 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -214,6 +214,15 @@ async def on_validation_start(self, trainer_state: TrainerState) -> bool: trainer_state.is_training = False return True + async def on_policy_updated(self, trainer_state: TrainerState) -> None: + """Hook called immediately after update_policy(). Backends sync weights here. + + For Tinker: save checkpoint, create new sampling_client. + For Verl (future): trigger NCCL sync to rollout workers. + Default: no-op (sync mode uses on_batch_end for this). + """ + pass + async def on_validation_end(self, trainer_state: TrainerState) -> None: """Hook method called at the end of validation.""" trainer_state.is_training = True diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 2d1521362..b3d84f6f4 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -12,6 +12,7 @@ from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset +from rllm.experimental.async_config import AsyncTrainingConfig from rllm.experimental.common.advantage import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, @@ -30,6 +31,7 @@ from rllm.experimental.common.transform import _default_traj_grouping_hook, transform_episodes_to_trajectory_groups from rllm.experimental.common.visualization import visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine +from rllm.experimental.experience_buffer import AsyncioExperienceBuffer, BufferedExperience, ExperienceBufferProtocol from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine from rllm.utils import EpisodeLogger, Tracking, extract_source_metadata @@ -45,6 +47,7 @@ class TrainerState: epoch: int = 0 total_steps: int = 0 is_training: bool = True + policy_version: int = 0 # For timing and metrics timing_dict: dict = field(default_factory=dict) metrics: dict = field(default_factory=dict) @@ -123,6 +126,15 @@ def __init__( self._validate_and_setup_configs() self._setup_logging() + # Async training config + async_cfg = self.rllm_config.get("async_training", {}) + self.async_config = AsyncTrainingConfig( + enabled=async_cfg.get("enabled", False), + max_staleness=async_cfg.get("max_staleness", 0), + buffer_size=async_cfg.get("buffer_size", 1), + requeue_stale=async_cfg.get("requeue_stale", True), + ) + rollout_engine: RolloutEngine = self.backend.init_rollout_engine( cf_config=self.cf_config, transform_config=self.transform_config, @@ -242,7 +254,14 @@ async def fit_async(self) -> None: await self.backend.on_train_end(trainer_state) async def _fit_async(self, trainer_state: TrainerState) -> None: - """Internal async main training loop.""" + """Dispatch to sync or concurrent training based on config.""" + if self.async_config.enabled: + await self._fit_async_concurrent(trainer_state) + else: + await self._fit_sync(trainer_state) + + async def _fit_sync(self, trainer_state: TrainerState) -> None: + """Synchronous training loop (original behavior).""" train_dataloader: Iterable = self.backend.get_dataloader(self.train_dataset, trainer_state) break_via_total_batches = False # used to break the training loop via the `total_batches` parameter use_total_batches = self.rllm_config.trainer.get("total_batches") is not None and self.rllm_config.trainer.total_batches > 0 @@ -353,6 +372,162 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N for r in TerminationReason: trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts + # ========================================================================= + # Concurrent (async) training methods + # ========================================================================= + + async def _fit_async_concurrent(self, trainer_state: TrainerState) -> None: + """Concurrent generation + training with experience buffer.""" + buffer = AsyncioExperienceBuffer( + max_size=self.async_config.buffer_size, + max_staleness=self.async_config.max_staleness, + requeue_stale=self.async_config.requeue_stale, + ) + + # Compute total_steps for LR scheduling + train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + if use_total_batches: + trainer_state.total_steps = self.rllm_config.trainer.total_batches + else: + trainer_state.total_steps = len(train_dataloader) * self.rllm_config.trainer.total_epochs + + await asyncio.gather( + self._generation_loop(trainer_state, buffer), + self._training_loop(trainer_state, buffer), + ) + + async def _generation_loop(self, trainer_state: TrainerState, buffer: ExperienceBufferProtocol) -> None: + """Generate episodes and push to buffer. Runs concurrently with training.""" + try: + for epoch in range(self.rllm_config.trainer.total_epochs): + train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) + for batch in train_dataloader: + # Check for requeued stale batches first + requeue_batch = await buffer.get_requeue_batch() + if requeue_batch is not None: + batch = requeue_batch + + # Snapshot current policy version BEFORE generation starts + gen_policy_version = trainer_state.policy_version + + # Set training step metadata on workflow engine + self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) + + # Stage 1: Generate episodes (async) + episodes = await self.backend.generate_episodes( + batch, + agent_workflow_engine=self.agent_workflow_engine, + is_validation=False, + ) + if not episodes: + continue + + # Stage 2: Transform to trajectory groups (sync) + trajectory_groups, _ = transform_episodes_to_trajectory_groups( + episodes, + self.transform_config, + self.cf_config, + traj_grouping_hook=self.traj_grouping_hook, + ) + + # Stage 3: Rejection sampling (sync) + filtered_groups, filtered_episodes, _ = apply_rejection_sampling_and_filtering( + episodes, + trajectory_groups, + self.rs_config, + RejectionSamplingState(), + ) + if not filtered_groups: + continue + + # Push to buffer (blocks if full = backpressure) + experience = BufferedExperience( + trajectory_groups=filtered_groups, + episodes=filtered_episodes, + policy_version=gen_policy_version, + batch_source=batch, + ) + await buffer.put(experience) + + # Check total_batches limit + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: + break + finally: + buffer.mark_generation_complete() + + async def _training_loop(self, trainer_state: TrainerState, buffer: ExperienceBufferProtocol) -> None: + """Consume from buffer and train. Runs concurrently with generation.""" + while True: + experience = await buffer.get(current_policy_version=trainer_state.policy_version) + if experience is None: + break # Generation complete and buffer drained + + # Load experience into trainer_state + trainer_state.reset_batch() + trainer_state.trajectory_groups = experience.trajectory_groups + trainer_state.episodes = experience.episodes + + # Collect workflow metrics + workflow_metrics, termination_counts = self._collect_workflow_metrics_from_episodes(trainer_state.episodes) + + # Stages 4-7: Backend training pipeline + await self.backend.on_batch_start(trainer_state) + + # Stage 4: Transform to backend batch + trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) + + # Stage 5: Process backend batch (forward pass) + await self.backend.process_backend_batch(trainer_state) + + # Stage 6: Compute advantages + await self.backend.compute_advantages(trainer_state, self.algorithm_config) + + # Stage 7: Update policy + await self.backend.update_policy(trainer_state) + + # Increment policy version (AFTER update, BEFORE on_policy_updated) + trainer_state.policy_version += 1 + + # Notify backend of policy update (sampling_client refresh for Tinker) + await self.backend.on_policy_updated(trainer_state) + + # Stage 8: Logging, visualization, metrics + trainer_state.metrics.update(buffer.stats()) + trainer_state.metrics["async/experience_staleness"] = trainer_state.policy_version - 1 - experience.policy_version + + # Visualization + if self.tokenizer is not None: + visualize_trajectory_last_steps( + trainer_state.trajectory_groups, + tokenizer=self.tokenizer, + max_steps_to_visualize=2, + show_workflow_metadata=True, + ) + + for key, value in workflow_metrics.items(): + trainer_state.metrics[f"batch/{key}"] = np.mean(value) + + total_counts = max(sum(termination_counts.values()), 1) + for r in TerminationReason: + trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts + + await self.backend.on_batch_end(trainer_state) + + self.logger.log( + data=trainer_state.metrics, + step=trainer_state.global_step, + episodes=trainer_state.episodes, + trajectory_groups=trainer_state.trajectory_groups, + ) + + # Periodic validation + if self.rllm_config.trainer.test_freq > 0 and trainer_state.global_step % self.rllm_config.trainer.test_freq == 0: + await self._validate_async(trainer_state) + + trainer_state.global_step += 1 + async def _validate_async(self, trainer_state: TrainerState) -> dict: """Validate the model (async implementation).""" n_val_samples = self.rllm_config.rollout.n_val diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 51385bcd7..3738c3d51 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -98,6 +98,9 @@ def __init__( # Store algorithm config for use in process_backend_batch self._algorithm_config: AlgorithmConfig | None = None + # Track whether on_policy_updated was called this step (for backward compat) + self._policy_updated_this_step: bool = False + # Specific optimizer parameters for Tinker self.learning_rate = self.full_config.training.get("learning_rate", 1e-6) self.beta1 = self.full_config.training.get("beta1", 0.9) @@ -383,20 +386,30 @@ async def on_train_end(self, trainer_state: TrainerState) -> None: logger.info(f"Saving final checkpoint at step {trainer_state.global_step}") await self.policy_trainer.save_checkpoint_and_get_sampling_client(trainer_state.global_step, kind="both", do_save=True) + async def on_policy_updated(self, trainer_state: TrainerState) -> None: + """Save checkpoint and update sampling_client after policy update.""" + assert self.policy_trainer is not None, "policy_trainer is not initialized" + self._policy_updated_this_step = True + + global_step = trainer_state.global_step + save_freq = self.full_config.rllm.trainer.save_freq + do_save = save_freq > 0 and global_step % save_freq == 0 + self.sampling_client = await self.policy_trainer.save_checkpoint_and_get_sampling_client(global_step, kind="both", do_save=do_save) + async def on_batch_end(self, trainer_state: TrainerState) -> None: """Called at the end of each batch. - Saves checkpoint, updates sampling client, and prints metrics. + In sync mode, on_policy_updated() is not called separately, so we + do the checkpoint/sampling_client update here for backward compat. """ assert self.policy_trainer is not None, "policy_trainer is not initialized" - global_step = trainer_state.global_step - # Save sampler checkpoint after each batch - with simple_timer("save_checkpoint", trainer_state.timing_dict): - logger.info(f"Saving state checkpoint and sampler at step {global_step}") - save_freq = self.full_config.rllm.trainer.save_freq - do_save = save_freq > 0 and global_step % save_freq == 0 - self.sampling_client = await self.policy_trainer.save_checkpoint_and_get_sampling_client(global_step, kind="both", do_save=do_save) + # If on_policy_updated() wasn't called (sync mode), do checkpoint here + if not self._policy_updated_this_step: + with simple_timer("save_checkpoint", trainer_state.timing_dict): + logger.info(f"Saving state checkpoint and sampler at step {trainer_state.global_step}") + await self.on_policy_updated(trainer_state) + self._policy_updated_this_step = False # Update metrics learning_rate = trainer_state.extra_info.get("scheduled_learning_rate", self.learning_rate) @@ -404,7 +417,7 @@ async def on_batch_end(self, trainer_state: TrainerState) -> None: # Print metrics table if trainer_state.metrics: - print_metrics_table(trainer_state.metrics, global_step) + print_metrics_table(trainer_state.metrics, trainer_state.global_step) async def on_epoch_start(self, trainer_state: TrainerState) -> None: """Called at the start of an epoch.""" From fd69d8f33a2a1a675cbba4ee6d0bf14d6b8472d3 Mon Sep 17 00:00:00 2001 From: listar2000 <35262801+listar2000@users.noreply.github.com> Date: Sun, 8 Mar 2026 00:38:03 -0600 Subject: [PATCH 02/21] add coordinator control and refactor queue --- rllm/experimental/async_config.py | 26 -- rllm/experimental/common/__init__.py | 2 + rllm/experimental/common/config.py | 23 ++ rllm/experimental/config/rllm/base.yaml | 10 +- .../engine/unified_workflow_engine.py | 45 ++++ rllm/experimental/episode_buffer.py | 101 ++++++++ rllm/experimental/experience_buffer.py | 111 --------- rllm/experimental/protocol.py | 28 ++- rllm/experimental/sync_coordinator.py | 113 +++++++++ rllm/experimental/unified_trainer.py | 229 +++++++++++------- rllm/trainer/tinker/tinker_backend.py | 23 ++ 11 files changed, 480 insertions(+), 231 deletions(-) delete mode 100644 rllm/experimental/async_config.py create mode 100644 rllm/experimental/episode_buffer.py delete mode 100644 rllm/experimental/experience_buffer.py create mode 100644 rllm/experimental/sync_coordinator.py diff --git a/rllm/experimental/async_config.py b/rllm/experimental/async_config.py deleted file mode 100644 index df4a46c38..000000000 --- a/rllm/experimental/async_config.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Configuration for async (concurrent generation + training) mode.""" - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass -class AsyncTrainingConfig: - """Controls the async training behavior spectrum. - - When `enabled` is False, the trainer uses the current synchronous pipeline. - When `enabled` is True, the trainer runs concurrent generation + training - with staleness-based filtering. - - Behavior spectrum: - - max_staleness=0, buffer_size=1: Effectively synchronous (backpressure serializes) - - max_staleness=1, buffer_size=2: 1-step overlap - - max_staleness=k, buffer_size=k+1: k-step off-policy - - max_staleness=5, buffer_size=8: Fully async with aggressive filtering - """ - - enabled: bool = False - max_staleness: int = 0 - buffer_size: int = 1 - requeue_stale: bool = True diff --git a/rllm/experimental/common/__init__.py b/rllm/experimental/common/__init__.py index ed169b372..b262590ea 100644 --- a/rllm/experimental/common/__init__.py +++ b/rllm/experimental/common/__init__.py @@ -7,6 +7,7 @@ from rllm.experimental.common.advantage import collect_reward_and_advantage_from_trajectory_groups from rllm.experimental.common.config import ( AlgorithmConfig, + AsyncTrainingConfig, CompactFilteringConfig, RejectionSamplingConfig, TransformConfig, @@ -24,6 +25,7 @@ __all__ = [ # Config + "AsyncTrainingConfig", "CompactFilteringConfig", "RejectionSamplingConfig", "TransformConfig", diff --git a/rllm/experimental/common/config.py b/rllm/experimental/common/config.py index 3adff6e04..617e36797 100644 --- a/rllm/experimental/common/config.py +++ b/rllm/experimental/common/config.py @@ -8,6 +8,29 @@ from rllm.workflows.workflow import TerminationReason +@dataclass +class AsyncTrainingConfig: + """Controls the async training behavior spectrum. + + When `enabled` is False, the trainer uses the current synchronous pipeline. + When `enabled` is True, the trainer runs concurrent generation + training + with episode-level streaming and staleness-based filtering. + + Behavior spectrum (following the Verl fully-async pattern): + - staleness_threshold=0, trigger_parameter_sync_step=1: On-policy (panel a) + - staleness_threshold=0, trigger_parameter_sync_step=K: Stream off-policy (panel b) + - staleness_threshold>0, partial_rollout=False: Async with staleness (panel c) + - staleness_threshold>0, partial_rollout=True: Async with partial rollout (panel d) + """ + + enabled: bool = False + staleness_threshold: float = 0.0 # 0.0 = on-policy. Fraction of extra samples allowed. + trigger_parameter_sync_step: int = 1 # gradient updates between weight syncs + partial_rollout: bool = True # True = don't wait for in-flight to finish at sync + num_minibatches: int = 1 # gradient accumulation within a training step + requeue_stale: bool = True # re-schedule stale episodes' tasks for generation + + @dataclass class CompactFilteringConfig: """Configuration for compact filtering of episodes based on termination reasons. diff --git a/rllm/experimental/config/rllm/base.yaml b/rllm/experimental/config/rllm/base.yaml index 07eb4c459..2ad1fb171 100644 --- a/rllm/experimental/config/rllm/base.yaml +++ b/rllm/experimental/config/rllm/base.yaml @@ -109,10 +109,12 @@ sdk: # Async Training Configuration async_training: - enabled: false # When false, use current synchronous pipeline - max_staleness: 0 # Max policy version gap (0 = effectively synchronous) - buffer_size: 1 # Max buffered experiences - requeue_stale: true # Requeue stale batches for regeneration + enabled: false # When false, use current synchronous pipeline + staleness_threshold: 0.0 # 0.0 = on-policy. Fraction of extra samples allowed. + trigger_parameter_sync_step: 1 # gradient updates between weight syncs + partial_rollout: true # True = don't wait for in-flight to finish at sync + num_minibatches: 1 # gradient accumulation within a training step + requeue_stale: true # re-schedule stale episodes' tasks for generation # Episode Logging Configuration episode_logging: diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index 8f2e7e80c..54baea048 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -220,6 +220,51 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No return ordered_results + async def execute_tasks_streaming( + self, + tasks: list[dict], + task_ids: list[str] | None = None, + queue: asyncio.Queue | None = None, + is_validation: bool = False, + **kwargs, + ) -> None: + """Run async workflow execution, pushing each completed episode to queue immediately. + + Concurrency is bounded by the existing workflow_queue (acts as semaphore). + No episode logging, no tqdm — designed for the fully-async training path. + + Each completed episode is pushed as a tuple: (task_id, rollout_idx, result_idx, episode). + + Args: + tasks: List of task dictionaries to process. + task_ids: Optional list of task identifiers. If None, UUIDs are generated. + queue: asyncio.Queue to push completed episodes into. + is_validation: Whether the generation is for validation. + **kwargs: Additional arguments passed to individual task processing. + """ + assert queue is not None, "queue must be provided for streaming execution" + if self.workflow_queue is None: + await self.initialize_pool() + + self.rollout_engine.is_validation = is_validation + + if task_ids is None: + task_ids = [str(uuid.uuid4()) for _ in tasks] + + task_id_counter = defaultdict(int) + + async def _process_and_push(task, task_id, rollout_idx, result_idx): + result = await self.process_task_with_retry(task, task_id, rollout_idx, result_idx, **kwargs) + await queue.put(result) + + tasks_to_run = [] + for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): + rollout_idx = task_id_counter[task_id] + tasks_to_run.append(_process_and_push(task, task_id, rollout_idx, idx)) + task_id_counter[task_id] += 1 + + await asyncio.gather(*tasks_to_run) + # TODO(listar2000): eventually the agent_workflow_engine should be backend agnostic. async def execute_tasks_verl(self, batch: DataProto, is_validation: bool = False, **kwargs) -> list[Episode]: """Execute tasks from a Verl DataProto batch and return results. diff --git a/rllm/experimental/episode_buffer.py b/rllm/experimental/episode_buffer.py new file mode 100644 index 000000000..32f1b5e8b --- /dev/null +++ b/rllm/experimental/episode_buffer.py @@ -0,0 +1,101 @@ +"""Episode buffer protocol and asyncio implementation for async training. + +The buffer is a dumb pipe — no staleness filtering. Staleness is checked +at consumption time by the training loop. Backpressure is managed externally +by SyncCoordinator's rollout quota. +""" + +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from rllm.agents.agent import Episode + + +@dataclass +class BufferedEpisode: + """A single episode stored in the buffer.""" + + episode: Episode + policy_version: int + task: dict # original task dict (for requeuing) + task_id: str # denormalized from episode.task_id + + +class EpisodeBufferProtocol(ABC): + """Abstract base class for episode buffers. + + Different backends can provide different implementations: + - AsyncioEpisodeBuffer: Single-threaded asyncio.Queue for Tinker + - RayEpisodeBuffer (future): Ray actor for multi-process Verl + """ + + @abstractmethod + async def put(self, item: BufferedEpisode) -> None: + """Add an episode to the buffer.""" + + @abstractmethod + async def get(self) -> BufferedEpisode | None: + """Get next episode. Returns None when generation is done and buffer is empty.""" + + @abstractmethod + def mark_generation_complete(self) -> None: + """Signal that generation is finished.""" + + @abstractmethod + def qsize(self) -> int: + """Current number of episodes in the buffer.""" + + @abstractmethod + def stats(self) -> dict: + """Buffer statistics for metrics.""" + + +class AsyncioEpisodeBuffer(EpisodeBufferProtocol): + """Unbounded asyncio.Queue-based buffer for Tinker backend. + + No staleness filtering — that happens at consumption time. + No max queue size — quota controls growth externally via SyncCoordinator. + Tinker's compute happens on remote servers, so the Python process only + orchestrates — no threading needed. + """ + + def __init__(self): + self._queue: asyncio.Queue[BufferedEpisode | None] = asyncio.Queue() # unbounded + self._generation_complete = False + self._total_produced = 0 + self._total_consumed = 0 + + async def put(self, item: BufferedEpisode) -> None: + await self._queue.put(item) + self._total_produced += 1 + + async def get(self) -> BufferedEpisode | None: + while True: + if self._generation_complete and self._queue.empty(): + return None + try: + item = await asyncio.wait_for(self._queue.get(), timeout=1.0) + except asyncio.TimeoutError: + if self._generation_complete and self._queue.empty(): + return None + continue + if item is None: # sentinel + return None + self._total_consumed += 1 + return item + + def mark_generation_complete(self) -> None: + self._generation_complete = True + + def qsize(self) -> int: + return self._queue.qsize() + + def stats(self) -> dict: + return { + "async/episode_buffer_size": self._queue.qsize(), + "async/total_produced": self._total_produced, + "async/total_consumed": self._total_consumed, + } diff --git a/rllm/experimental/experience_buffer.py b/rllm/experimental/experience_buffer.py deleted file mode 100644 index cc58c1628..000000000 --- a/rllm/experimental/experience_buffer.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Experience buffer protocol and asyncio implementation for async training.""" - -from __future__ import annotations - -import asyncio -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any - -from rllm.agents.agent import Episode, TrajectoryGroup - - -@dataclass -class BufferedExperience: - """A unit of experience stored in the buffer.""" - - trajectory_groups: list[TrajectoryGroup] - episodes: list[Episode] - policy_version: int - batch_source: Any # Original batch (for requeuing) - - -class ExperienceBufferProtocol(ABC): - """Abstract base class for experience buffers. - - Different backends can provide different implementations: - - AsyncioExperienceBuffer: Single-threaded asyncio.Queue for Tinker - - RayExperienceBuffer (future): Ray actor for multi-process Verl - """ - - @abstractmethod - async def put(self, experience: BufferedExperience) -> None: - """Add experience. Blocks if buffer full (backpressure).""" - - @abstractmethod - async def get(self, current_policy_version: int) -> BufferedExperience | None: - """Get next non-stale experience. Returns None when done.""" - - @abstractmethod - def mark_generation_complete(self) -> None: - """Signal that generation is finished.""" - - @abstractmethod - async def get_requeue_batch(self) -> Any | None: - """Get a stale batch to regenerate, or None.""" - - @abstractmethod - def stats(self) -> dict: - """Buffer statistics for metrics.""" - - -class AsyncioExperienceBuffer(ExperienceBufferProtocol): - """Single-threaded asyncio-based buffer for Tinker backend. - - Uses asyncio.Queue for backpressure. Tinker's compute happens on remote - servers, so the Python process only orchestrates — no threading needed. - """ - - def __init__(self, max_size: int, max_staleness: int, requeue_stale: bool): - self._queue: asyncio.Queue[BufferedExperience | None] = asyncio.Queue(maxsize=max_size) - self._requeue_queue: asyncio.Queue = asyncio.Queue() - self._max_staleness = max_staleness - self._requeue_stale = requeue_stale - self._generation_complete = False - # Stats - self._total_produced = 0 - self._total_consumed = 0 - self._total_discarded = 0 - - async def put(self, experience: BufferedExperience) -> None: - await self._queue.put(experience) - self._total_produced += 1 - - async def get(self, current_policy_version: int) -> BufferedExperience | None: - while True: - if self._generation_complete and self._queue.empty(): - return None - try: - experience = await asyncio.wait_for(self._queue.get(), timeout=1.0) - except asyncio.TimeoutError: - if self._generation_complete and self._queue.empty(): - return None - continue - if experience is None: # sentinel - return None - # Staleness check - version_gap = current_policy_version - experience.policy_version - if version_gap > self._max_staleness: - self._total_discarded += 1 - if self._requeue_stale: - await self._requeue_queue.put(experience.batch_source) - continue - self._total_consumed += 1 - return experience - - def mark_generation_complete(self) -> None: - self._generation_complete = True - - async def get_requeue_batch(self) -> Any | None: - try: - return self._requeue_queue.get_nowait() - except asyncio.QueueEmpty: - return None - - def stats(self) -> dict: - return { - "async/buffer_size": self._queue.qsize(), - "async/total_produced": self._total_produced, - "async/total_consumed": self._total_consumed, - "async/total_discarded": self._total_discarded, - } diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index 1e1027705..998be4989 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -106,6 +106,30 @@ async def generate_episodes( """ raise NotImplementedError("Subclasses must implement this method.") + async def generate_episodes_streaming( + self, + batch: TBatch, + agent_workflow_engine: UnifiedWorkflowEngine, + episode_queue, + is_validation: bool = False, + **kwargs, + ) -> None: + """Generate episodes and push each to episode_queue as it completes. + + Default: falls back to generate_episodes() and pushes all to queue. + Backends can override for true streaming. + + Args: + batch: The input batch. + agent_workflow_engine: The workflow engine to use. + episode_queue: asyncio.Queue to push (task_id, rollout_idx, result_idx, episode) tuples. + is_validation: Whether the generation is for validation. + **kwargs: Additional arguments. + """ + episodes = await self.generate_episodes(batch, agent_workflow_engine, is_validation, **kwargs) + for i, ep in enumerate(episodes): + await episode_queue.put((ep.task_id, getattr(ep, "rollout_idx", 0), i, ep)) + @abstractmethod def transform_to_backend_batch( self, @@ -217,8 +241,8 @@ async def on_validation_start(self, trainer_state: TrainerState) -> bool: async def on_policy_updated(self, trainer_state: TrainerState) -> None: """Hook called immediately after update_policy(). Backends sync weights here. - For Tinker: save checkpoint, create new sampling_client. - For Verl (future): trigger NCCL sync to rollout workers. + For Tinker-like remote/distributed backends: save checkpoint, create new sampling_client. + For Verl-like colocated backends: trigger NCCL sync to rollout workers. Default: no-op (sync mode uses on_batch_end for this). """ pass diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py new file mode 100644 index 000000000..7f60dc6fc --- /dev/null +++ b/rllm/experimental/sync_coordinator.py @@ -0,0 +1,113 @@ +"""SyncCoordinator: manages rollout quotas and parameter sync timing for fully-async training.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + + +@dataclass +class SyncCoordinatorConfig: + """Configuration derived from trainer config and async config for the coordinator to use.""" + + train_batch_size: int # from config.data.train_batch_size + group_size: int # from config.rllm.rollout.n + staleness_threshold: float # from async config + trigger_parameter_sync_step: int # from async config + requeue_stale: bool # from async config + num_minibatches: int # from async config + + @property + def episodes_per_train_step(self) -> int: + """Number of episodes per training step.""" + return self.train_batch_size * self.group_size + + @property + def max_rollout_quota(self) -> int: + """Max episodes in-flight + in-queue between syncs.""" + return int((1 + self.staleness_threshold) * self.trigger_parameter_sync_step * self.episodes_per_train_step) + + +class SyncCoordinator: + """Coordinates rollout scheduling and parameter sync between generation and training loops. + + Core responsibility: control how many episodes can be in-flight + queued, + and when to trigger weight synchronization. + """ + + def __init__(self, config: SyncCoordinatorConfig): + self.config = config + + # State + self._policy_version: int = 0 + self._scheduled_count: int = 0 # in-flight episodes not yet in buffer + self._stale_requeue_count: int = 0 # stale tasks to re-add to quota + self._steps_since_sync: int = 0 # training steps since last sync + self._total_stale_discarded: int = 0 + + # Events + self.sync_complete_event: asyncio.Event = asyncio.Event() + self.sync_complete_event.set() # initially unblocked + self.generation_done: bool = False + + @property + def policy_version(self) -> int: + return self._policy_version + + def compute_new_schedule_count(self, remain_in_queue: int) -> int: + """Compute how many new episodes the generation loop should schedule. + + Formula: new = max_rollout_quota - remain_in_queue - scheduled + requeue_stale + """ + available = self.config.max_rollout_quota - remain_in_queue - self._scheduled_count + self._stale_requeue_count + self._stale_requeue_count = 0 # consumed + return max(0, available) + + def on_episodes_scheduled(self, count: int) -> None: + """Called when the generation loop dispatches tasks.""" + self._scheduled_count += count + + def on_episode_generated(self, count: int = 1) -> None: + """Called when an episode arrives in the buffer (no longer in-flight).""" + self._scheduled_count = max(0, self._scheduled_count - count) + + def on_training_step_complete(self) -> None: + """Called after a gradient update.""" + self._steps_since_sync += 1 + + def should_sync(self) -> bool: + """Whether it's time to synchronize parameters.""" + return self._steps_since_sync >= self.config.trigger_parameter_sync_step + + def on_sync_complete(self) -> None: + """Called after weight sync. Bumps policy version, resets counters, signals gen loop.""" + self._policy_version += 1 + self._steps_since_sync = 0 + self.sync_complete_event.set() + + def on_stale_discarded(self, count: int, requeue: bool) -> None: + """Called when the training loop discards stale episodes.""" + self._total_stale_discarded += count + if requeue: + self._stale_requeue_count += count + + def is_episode_stale(self, ep_version: int) -> bool: + """Check if an episode is too stale to use. + + An episode is stale if the version gap exceeds: + staleness_threshold * trigger_parameter_sync_step + """ + if self.config.staleness_threshold == 0.0: + # On-policy: only current version is acceptable + return ep_version < self._policy_version + max_gap = self.config.staleness_threshold * self.config.trigger_parameter_sync_step + return (self._policy_version - ep_version) > max_gap + + def stats(self) -> dict: + return { + "async/policy_version": self._policy_version, + "async/scheduled_count": self._scheduled_count, + "async/steps_since_sync": self._steps_since_sync, + "async/total_stale_discarded": self._total_stale_discarded, + "async/max_rollout_quota": self.config.max_rollout_quota, + } diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index b3d84f6f4..d8d48bd2e 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -12,12 +12,12 @@ from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset -from rllm.experimental.async_config import AsyncTrainingConfig from rllm.experimental.common.advantage import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, ) from rllm.experimental.common.config import ( + AsyncTrainingConfig, CompactFilteringConfig, RejectionSamplingConfig, TransformConfig, @@ -31,9 +31,10 @@ from rllm.experimental.common.transform import _default_traj_grouping_hook, transform_episodes_to_trajectory_groups from rllm.experimental.common.visualization import visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine -from rllm.experimental.experience_buffer import AsyncioExperienceBuffer, BufferedExperience, ExperienceBufferProtocol +from rllm.experimental.episode_buffer import AsyncioEpisodeBuffer, BufferedEpisode from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine +from rllm.experimental.sync_coordinator import SyncCoordinator, SyncCoordinatorConfig from rllm.utils import EpisodeLogger, Tracking, extract_source_metadata from rllm.workflows.workflow import TerminationReason, Workflow @@ -130,8 +131,10 @@ def __init__( async_cfg = self.rllm_config.get("async_training", {}) self.async_config = AsyncTrainingConfig( enabled=async_cfg.get("enabled", False), - max_staleness=async_cfg.get("max_staleness", 0), - buffer_size=async_cfg.get("buffer_size", 1), + staleness_threshold=async_cfg.get("staleness_threshold", 0.0), + trigger_parameter_sync_step=async_cfg.get("trigger_parameter_sync_step", 1), + partial_rollout=async_cfg.get("partial_rollout", True), + num_minibatches=async_cfg.get("num_minibatches", 1), requeue_stale=async_cfg.get("requeue_stale", True), ) @@ -255,13 +258,14 @@ async def fit_async(self) -> None: async def _fit_async(self, trainer_state: TrainerState) -> None: """Dispatch to sync or concurrent training based on config.""" + # TODO(listar2000): after some benchmarking, maybe we just keep the fully-async and treat on-policy as a special case. if self.async_config.enabled: - await self._fit_async_concurrent(trainer_state) + await self._fit_fully_async(trainer_state) else: - await self._fit_sync(trainer_state) + await self._fit_on_policy(trainer_state) - async def _fit_sync(self, trainer_state: TrainerState) -> None: - """Synchronous training loop (original behavior).""" + async def _fit_on_policy(self, trainer_state: TrainerState) -> None: + """Synchronous training loop (the most vanilla, standalone case that does not support minibatching or off-policy training).""" train_dataloader: Iterable = self.backend.get_dataloader(self.train_dataset, trainer_state) break_via_total_batches = False # used to break the training loop via the `total_batches` parameter use_total_batches = self.rllm_config.trainer.get("total_batches") is not None and self.rllm_config.trainer.total_batches > 0 @@ -376,13 +380,18 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N # Concurrent (async) training methods # ========================================================================= - async def _fit_async_concurrent(self, trainer_state: TrainerState) -> None: - """Concurrent generation + training with experience buffer.""" - buffer = AsyncioExperienceBuffer( - max_size=self.async_config.buffer_size, - max_staleness=self.async_config.max_staleness, + async def _fit_fully_async(self, trainer_state: TrainerState) -> None: + """Fully-async generation + training with episode-level streaming.""" + coord_config = SyncCoordinatorConfig( + train_batch_size=self.config.data.train_batch_size, + group_size=self.rllm_config.rollout.n, + staleness_threshold=self.async_config.staleness_threshold, + trigger_parameter_sync_step=self.async_config.trigger_parameter_sync_step, requeue_stale=self.async_config.requeue_stale, + num_minibatches=self.async_config.num_minibatches, ) + coordinator = SyncCoordinator(coord_config) + buffer = AsyncioEpisodeBuffer() # Compute total_steps for LR scheduling train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) @@ -393,111 +402,151 @@ async def _fit_async_concurrent(self, trainer_state: TrainerState) -> None: trainer_state.total_steps = len(train_dataloader) * self.rllm_config.trainer.total_epochs await asyncio.gather( - self._generation_loop(trainer_state, buffer), - self._training_loop(trainer_state, buffer), + self._generation_loop(trainer_state, buffer, coordinator), + self._training_loop(trainer_state, buffer, coordinator), ) - async def _generation_loop(self, trainer_state: TrainerState, buffer: ExperienceBufferProtocol) -> None: - """Generate episodes and push to buffer. Runs concurrently with training.""" + async def _generation_loop(self, trainer_state: TrainerState, buffer: AsyncioEpisodeBuffer, coordinator: SyncCoordinator) -> None: + """Generate episodes and stream to buffer. Quota-controlled by SyncCoordinator.""" try: + group_size = self.rllm_config.rollout.n + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + for epoch in range(self.rllm_config.trainer.total_epochs): train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) - for batch in train_dataloader: - # Check for requeued stale batches first - requeue_batch = await buffer.get_requeue_batch() - if requeue_batch is not None: - batch = requeue_batch - # Snapshot current policy version BEFORE generation starts - gen_policy_version = trainer_state.policy_version + # Flatten to individual tasks + all_tasks = [] + for batch in train_dataloader: + all_tasks.extend(batch) - # Set training step metadata on workflow engine - self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) + task_idx = 0 + while task_idx < len(all_tasks): + # Check total_batches limit + if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: + return - # Stage 1: Generate episodes (async) - episodes = await self.backend.generate_episodes( - batch, - agent_workflow_engine=self.agent_workflow_engine, - is_validation=False, - ) - if not episodes: + # 1. Compute quota + quota = coordinator.compute_new_schedule_count(buffer.qsize()) + if quota <= 0: + coordinator.sync_complete_event.clear() + await coordinator.sync_complete_event.wait() continue - # Stage 2: Transform to trajectory groups (sync) - trajectory_groups, _ = transform_episodes_to_trajectory_groups( - episodes, - self.transform_config, - self.cf_config, - traj_grouping_hook=self.traj_grouping_hook, - ) - - # Stage 3: Rejection sampling (sync) - filtered_groups, filtered_episodes, _ = apply_rejection_sampling_and_filtering( - episodes, - trajectory_groups, - self.rs_config, - RejectionSamplingState(), - ) - if not filtered_groups: + # 2. Pick tasks up to quota (convert episodes to prompt count) + prompts_to_schedule = min(quota // group_size, len(all_tasks) - task_idx) + if prompts_to_schedule <= 0: + await asyncio.sleep(0.1) continue + chunk = all_tasks[task_idx : task_idx + prompts_to_schedule] + task_idx += prompts_to_schedule - # Push to buffer (blocks if full = backpressure) - experience = BufferedExperience( - trajectory_groups=filtered_groups, - episodes=filtered_episodes, - policy_version=gen_policy_version, - batch_source=batch, - ) - await buffer.put(experience) + total_episodes = prompts_to_schedule * group_size + coordinator.on_episodes_scheduled(total_episodes) + gen_policy_version = coordinator.policy_version - # Check total_batches limit - use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 - if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: - break + # Set training step metadata + self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) + + # 3. Stream generation — episodes pushed to episode_queue as they complete + episode_queue: asyncio.Queue = asyncio.Queue() + gen_task = asyncio.create_task(self.backend.generate_episodes_streaming(chunk, self.agent_workflow_engine, episode_queue, is_validation=False)) + + # 4. Drain episodes into buffer as they arrive + received = 0 + while received < total_episodes: + try: + task_id, rollout_idx, _, episode = await asyncio.wait_for(episode_queue.get(), timeout=300.0) + except asyncio.TimeoutError: + break + if episode is not None and episode.trajectories: + await buffer.put( + BufferedEpisode( + episode=episode, + policy_version=gen_policy_version, + task=episode.task, + task_id=episode.task_id, + ) + ) + coordinator.on_episode_generated(1) + received += 1 + + await gen_task # ensure streaming task completes finally: + coordinator.generation_done = True buffer.mark_generation_complete() - async def _training_loop(self, trainer_state: TrainerState, buffer: ExperienceBufferProtocol) -> None: - """Consume from buffer and train. Runs concurrently with generation.""" + async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpisodeBuffer, coordinator: SyncCoordinator) -> None: + """Collect episodes from buffer, group, train, and sync. Runs concurrently with generation.""" + episodes_per_step = coordinator.config.episodes_per_train_step + use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + while True: - experience = await buffer.get(current_policy_version=trainer_state.policy_version) - if experience is None: - break # Generation complete and buffer drained + # 1. Collect episodes for one training step + collected: list[BufferedEpisode] = [] + while len(collected) < episodes_per_step: + item = await buffer.get() + if item is None: + break # generation done + buffer empty + if coordinator.is_episode_stale(item.policy_version): + coordinator.on_stale_discarded(1, requeue=coordinator.config.requeue_stale) + continue + collected.append(item) + + if not collected: + if coordinator.generation_done and buffer.qsize() == 0: + break + continue - # Load experience into trainer_state + # 2. Extract episodes, transform to trajectory groups (at consumption time) + episodes = [be.episode for be in collected] trainer_state.reset_batch() - trainer_state.trajectory_groups = experience.trajectory_groups - trainer_state.episodes = experience.episodes - - # Collect workflow metrics - workflow_metrics, termination_counts = self._collect_workflow_metrics_from_episodes(trainer_state.episodes) + trainer_state.episodes = episodes + trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups( + episodes, + self.transform_config, + self.cf_config, + traj_grouping_hook=self.traj_grouping_hook, + ) + trainer_state.trajectory_groups = trajectory_groups + trainer_state.metrics.update(transform_metrics) + + # 3. Rejection sampling + filtered_groups, filtered_episodes, rs_metrics = apply_rejection_sampling_and_filtering( + episodes, + trajectory_groups, + self.rs_config, + RejectionSamplingState(), + ) + trainer_state.metrics.update(rs_metrics) + trainer_state.trajectory_groups = filtered_groups + trainer_state.episodes = filtered_episodes + if not trainer_state.has_trajectory_groups: + continue - # Stages 4-7: Backend training pipeline + # 4. Stages 4-7: backend training pipeline await self.backend.on_batch_start(trainer_state) - - # Stage 4: Transform to backend batch trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) - - # Stage 5: Process backend batch (forward pass) await self.backend.process_backend_batch(trainer_state) - - # Stage 6: Compute advantages await self.backend.compute_advantages(trainer_state, self.algorithm_config) - - # Stage 7: Update policy await self.backend.update_policy(trainer_state) - # Increment policy version (AFTER update, BEFORE on_policy_updated) - trainer_state.policy_version += 1 + # 5. Training step done — check sync + coordinator.on_training_step_complete() + if coordinator.should_sync(): + trainer_state.policy_version = coordinator.policy_version + 1 + await self.backend.on_policy_updated(trainer_state) + coordinator.on_sync_complete() - # Notify backend of policy update (sampling_client refresh for Tinker) - await self.backend.on_policy_updated(trainer_state) + # 6. Metrics, logging, visualization + workflow_metrics, termination_counts = self._collect_workflow_metrics_from_episodes(trainer_state.episodes) - # Stage 8: Logging, visualization, metrics + # Compute average staleness of consumed episodes + avg_staleness = np.mean([coordinator.policy_version - be.policy_version for be in collected]) + trainer_state.metrics["async/avg_episode_staleness"] = avg_staleness trainer_state.metrics.update(buffer.stats()) - trainer_state.metrics["async/experience_staleness"] = trainer_state.policy_version - 1 - experience.policy_version + trainer_state.metrics.update(coordinator.stats()) - # Visualization if self.tokenizer is not None: visualize_trajectory_last_steps( trainer_state.trajectory_groups, @@ -528,6 +577,10 @@ async def _training_loop(self, trainer_state: TrainerState, buffer: ExperienceBu trainer_state.global_step += 1 + # Check total_batches limit + if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: + break + async def _validate_async(self, trainer_state: TrainerState) -> dict: """Validate the model (async implementation).""" n_val_samples = self.rllm_config.rollout.n_val diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 3738c3d51..d89b38fcd 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -225,6 +225,29 @@ async def generate_episodes( return episodes + async def generate_episodes_streaming( + self, + batch: Any, + agent_workflow_engine: UnifiedWorkflowEngine, + episode_queue, + is_validation: bool = False, + **kwargs, + ) -> None: + """Generate episodes using streaming — push each to queue as it completes. + + Same setup as generate_episodes but uses execute_tasks_streaming. + """ + assert self.rollout_engine is not None, "rollout_engine is not initialized" + assert self.sampling_client is not None, "sampling_client is not initialized" + + self.rollout_engine.set_sampling_client(self.sampling_client) + + group_size = self.full_config.rllm.rollout.n_val if is_validation else self.full_config.rllm.rollout.n + interleaved_batch = _build_interleave_batch(batch, group_size) + task_ids = [item["uid"] for item in interleaved_batch] + + await agent_workflow_engine.execute_tasks_streaming(interleaved_batch, task_ids, queue=episode_queue, is_validation=is_validation, **kwargs) + def transform_to_backend_batch( self, trainer_state: TrainerState, From fb85d2a8f625cd74c6d6b64531938402fa846822 Mon Sep 17 00:00:00 2001 From: listar2000 <35262801+listar2000@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:41:53 -0500 Subject: [PATCH 03/21] cherrypick Kyle's async design refinements from kyle/deepresearch Adopts core async architecture improvements: BufferedEpisodeGroup with EpisodeGroupAccumulator, simplified SyncCoordinator with throttle and pause/resume, fire-and-forget generation loop, streaming gradient accumulation, and weight sync gate mechanism on RolloutEngine. Co-Authored-By: Claude Opus 4.6 --- rllm/experimental/common/__init__.py | 2 + rllm/experimental/common/config.py | 71 +++- rllm/experimental/config/rllm/base.yaml | 27 +- rllm/experimental/episode_buffer.py | 84 +++-- rllm/experimental/rollout/rollout_engine.py | 42 ++- rllm/experimental/sync_coordinator.py | 125 ++++--- rllm/experimental/unified_trainer.py | 341 ++++++++++++-------- 7 files changed, 461 insertions(+), 231 deletions(-) diff --git a/rllm/experimental/common/__init__.py b/rllm/experimental/common/__init__.py index b262590ea..b75f43c90 100644 --- a/rllm/experimental/common/__init__.py +++ b/rllm/experimental/common/__init__.py @@ -10,6 +10,7 @@ AsyncTrainingConfig, CompactFilteringConfig, RejectionSamplingConfig, + RolloutCorrectionConfig, TransformConfig, rLLMAdvantageEstimator, ) @@ -28,6 +29,7 @@ "AsyncTrainingConfig", "CompactFilteringConfig", "RejectionSamplingConfig", + "RolloutCorrectionConfig", "TransformConfig", "AlgorithmConfig", # Transform pipeline diff --git a/rllm/experimental/common/config.py b/rllm/experimental/common/config.py index 617e36797..4f57691b7 100644 --- a/rllm/experimental/common/config.py +++ b/rllm/experimental/common/config.py @@ -14,21 +14,26 @@ class AsyncTrainingConfig: When `enabled` is False, the trainer uses the current synchronous pipeline. When `enabled` is True, the trainer runs concurrent generation + training - with episode-level streaming and staleness-based filtering. + with group-level streaming and dispatch-time throttle. - Behavior spectrum (following the Verl fully-async pattern): - - staleness_threshold=0, trigger_parameter_sync_step=1: On-policy (panel a) - - staleness_threshold=0, trigger_parameter_sync_step=K: Stream off-policy (panel b) - - staleness_threshold>0, partial_rollout=False: Async with staleness (panel c) - - staleness_threshold>0, partial_rollout=True: Async with partial rollout (panel d) + Behavior spectrum: + - staleness_threshold=0, trigger_parameter_sync_step=1: On-policy + - staleness_threshold=0, trigger_parameter_sync_step=K: Stream off-policy + - staleness_threshold>0, partial_rollout=False: Async with staleness + - staleness_threshold>0, partial_rollout=True: Async with partial rollout """ enabled: bool = False - staleness_threshold: float = 0.0 # 0.0 = on-policy. Fraction of extra samples allowed. - trigger_parameter_sync_step: int = 1 # gradient updates between weight syncs - partial_rollout: bool = True # True = don't wait for in-flight to finish at sync - num_minibatches: int = 1 # gradient accumulation within a training step - requeue_stale: bool = True # re-schedule stale episodes' tasks for generation + mini_batch_size: int = 1 # episode groups per optimizer step + streaming_chunks: int = 1 # forward-backward passes per optimizer step (must divide mini_batch_size) + staleness_threshold: float = 0.0 # 0.0 = on-policy. Controls dispatch throttle quota. + trigger_parameter_sync_step: int = 1 # optimizer steps between weight sync + version bump + partial_rollout: bool = True # enable turn-level gating during weight sync + + def __post_init__(self): + if self.enabled: + assert self.streaming_chunks >= 1 + assert self.mini_batch_size % self.streaming_chunks == 0, f"mini_batch_size ({self.mini_batch_size}) must be divisible by streaming_chunks ({self.streaming_chunks})" @dataclass @@ -108,6 +113,30 @@ class RejectionSamplingConfig: # For "episode" mode (verl compatibility): minimum number of tasks with partial solves before proceeding min_partial_solve_tasks: int = 1 + # Filter out episode groups where all rollouts have the same is_correct (no gradient signal). + # Applied at the accumulator level in async training, before groups enter the buffer. + filter_uniform_groups: bool = False + + +@dataclass +class RolloutCorrectionConfig: + """Configuration for rollout correction (TIS, proximal forward passes). + + Backend-agnostic — each backend interprets these according to its infrastructure. + + Attributes: + mode: None = disabled (string loss names, current behavior). + "token" or "sequence" = enable custom callable loss with TIS at that level. + bypass_mode: When True, use rollout (inference) logprobs as π_old — no + proximal forward pass. When False, compute π_old via policy.forward() + (3-policy / decoupled PPO). + tis_cap: Upper clamp on the TIS importance weight. + """ + + mode: str | None = None + bypass_mode: bool = True + tis_cap: float = 5.0 + class rLLMAdvantageEstimator(str, Enum): """ @@ -142,10 +171,17 @@ class AlgorithmConfig: # When False (default), always compute advantages normally. use_precomputed_advantage: bool = False # for tinker backend only - loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro", "cross_entropy"] | None = None + loss_fn: Literal["importance_sampling", "ppo", "cispo", "dro", "cross_entropy", "grpo", "dapo", "gspo"] | None = None lr_schedule: Literal["linear", "cosine", "constant"] = "constant" warmup_steps_ratio: float = 0.0 + # Custom loss / rollout correction fields (used by Fireworks backend with cookbook losses) + kl_beta: float = 0.0 + eps_clip: float = 0.2 + eps_clip_high: float | None = None + rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig) + router_replay: bool = False + @classmethod def from_config(cls, config: DictConfig) -> "AlgorithmConfig": """Create an AlgorithmConfig from a dictionary configuration. @@ -155,6 +191,12 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": Returns: AlgorithmConfig: The AlgorithmConfig built from the configuration. """ + rc_section = config.rllm.algorithm.get("rollout_correction", {}) + rollout_correction = RolloutCorrectionConfig( + mode=rc_section.get("mode", None), + bypass_mode=rc_section.get("bypass_mode", True), + tis_cap=rc_section.get("tis_cap", 5.0), + ) return cls( estimator=rLLMAdvantageEstimator(config.algorithm.adv_estimator), stepwise_advantage_mode=config.rllm.stepwise_advantage.mode, @@ -164,6 +206,11 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": loss_fn=config.rllm.algorithm.get("loss_fn", None), lr_schedule=config.rllm.algorithm.get("lr_schedule", "constant"), warmup_steps_ratio=config.rllm.algorithm.get("warmup_steps_ratio", 0.0), + kl_beta=config.rllm.algorithm.get("kl_beta", 0.0), + eps_clip=config.rllm.algorithm.get("eps_clip", 0.2), + eps_clip_high=config.rllm.algorithm.get("eps_clip_high", None), + rollout_correction=rollout_correction, + router_replay=config.rllm.algorithm.get("router_replay", False), ) def __post_init__(self): diff --git a/rllm/experimental/config/rllm/base.yaml b/rllm/experimental/config/rllm/base.yaml index 2ad1fb171..f6f94e688 100644 --- a/rllm/experimental/config/rllm/base.yaml +++ b/rllm/experimental/config/rllm/base.yaml @@ -59,6 +59,20 @@ algorithm: # for tinker backend only (avaiable options: importance_sampling, ppo, cispo, dro, cross_entropy) loss_fn: null + # Custom loss / rollout correction (used by Fireworks backend with cookbook losses) + kl_beta: 0.0 # KL penalty coefficient; >0 enables reference forward pass + eps_clip: 0.2 # PPO clip epsilon + eps_clip_high: null # Asymmetric upper clip bound (null = symmetric) + + # Router Replay (R3): replay MoE expert routing from inference during training + router_replay: false + + # Rollout correction: corrects FP8 (inference) vs FP32 (training) drift + rollout_correction: + mode: null # null = disabled, "token" or "sequence" = enable custom callable loss + bypass_mode: true # true = use rollout logprobs as pi_old (2-policy), false = proximal forward (3-policy) + tis_cap: 5.0 # Upper clamp on TIS importance weight + # Stepwise advantage # TODO(listar2000): deprecate the `per_step` mode and refactor this config. stepwise_advantage: @@ -93,6 +107,7 @@ rejection_sample: multiplier: 1 min_partial_solve_tasks: 1 min_trajs_per_group: 2 + filter_uniform_groups: false # SDK Configuration sdk: @@ -109,12 +124,12 @@ sdk: # Async Training Configuration async_training: - enabled: false # When false, use current synchronous pipeline - staleness_threshold: 0.0 # 0.0 = on-policy. Fraction of extra samples allowed. - trigger_parameter_sync_step: 1 # gradient updates between weight syncs - partial_rollout: true # True = don't wait for in-flight to finish at sync - num_minibatches: 1 # gradient accumulation within a training step - requeue_stale: true # re-schedule stale episodes' tasks for generation + enabled: false + mini_batch_size: 1 + streaming_chunks: 1 + staleness_threshold: 0.0 + trigger_parameter_sync_step: 1 + partial_rollout: true # Episode Logging Configuration episode_logging: diff --git a/rllm/experimental/episode_buffer.py b/rllm/experimental/episode_buffer.py index 32f1b5e8b..df89e0719 100644 --- a/rllm/experimental/episode_buffer.py +++ b/rllm/experimental/episode_buffer.py @@ -1,8 +1,7 @@ """Episode buffer protocol and asyncio implementation for async training. -The buffer is a dumb pipe — no staleness filtering. Staleness is checked -at consumption time by the training loop. Backpressure is managed externally -by SyncCoordinator's rollout quota. +The buffer is a dumb pipe — no staleness filtering. Staleness is controlled +at dispatch time by SyncCoordinator's throttle quota. """ from __future__ import annotations @@ -15,13 +14,61 @@ @dataclass -class BufferedEpisode: - """A single episode stored in the buffer.""" +class BufferedEpisodeGroup: + """All n episodes for one prompt, collected before buffering.""" - episode: Episode - policy_version: int - task: dict # original task dict (for requeuing) - task_id: str # denormalized from episode.task_id + episodes: list[Episode] + weight_version: int # earliest weight_version across all steps in all episodes + task_id: str + + +class EpisodeGroupAccumulator: + """Per-task collector that groups episodes by task_id before pushing to buffer. + + Lives in the generation loop, NOT inside the buffer (buffer stays a dumb pipe). + Optionally filters out groups with no gradient signal (all correct or all incorrect). + """ + + def __init__( + self, + group_size: int, + buffer: EpisodeBufferProtocol, + filter_uniform_groups: bool = False, + on_group_filtered: callable | None = None, + ): + self._group_size = group_size + self._buffer = buffer + self._filter_uniform_groups = filter_uniform_groups + self._on_group_filtered = on_group_filtered + self._pending: dict[str, list[Episode]] = {} + self.total_filtered: int = 0 + + async def add_episode(self, task_id: str, episode: Episode) -> bool: + """Add episode. Returns True if group completed (pushed or filtered).""" + self._pending.setdefault(task_id, []).append(episode) + if len(self._pending[task_id]) == self._group_size: + episodes = self._pending.pop(task_id) + + if self._filter_uniform_groups and len({ep.is_correct for ep in episodes}) == 1: + self.total_filtered += 1 + if self._on_group_filtered: + self._on_group_filtered() + return True + + earliest = self._compute_earliest_version(episodes) + await self._buffer.put(BufferedEpisodeGroup(episodes=episodes, weight_version=earliest, task_id=task_id)) + return True + return False + + @staticmethod + def _compute_earliest_version(episodes: list[Episode]) -> int: + min_v = float("inf") + for ep in episodes: + for traj in ep.trajectories: + for step in traj.steps: + if step.weight_version is not None: + min_v = min(min_v, step.weight_version) + return int(min_v) if min_v != float("inf") else 0 class EpisodeBufferProtocol(ABC): @@ -33,12 +80,12 @@ class EpisodeBufferProtocol(ABC): """ @abstractmethod - async def put(self, item: BufferedEpisode) -> None: - """Add an episode to the buffer.""" + async def put(self, item: BufferedEpisodeGroup) -> None: + """Add an episode group to the buffer.""" @abstractmethod - async def get(self) -> BufferedEpisode | None: - """Get next episode. Returns None when generation is done and buffer is empty.""" + async def get(self) -> BufferedEpisodeGroup | None: + """Get next episode group. Returns None when generation is done and buffer is empty.""" @abstractmethod def mark_generation_complete(self) -> None: @@ -46,7 +93,7 @@ def mark_generation_complete(self) -> None: @abstractmethod def qsize(self) -> int: - """Current number of episodes in the buffer.""" + """Current number of episode groups in the buffer.""" @abstractmethod def stats(self) -> dict: @@ -56,23 +103,22 @@ def stats(self) -> dict: class AsyncioEpisodeBuffer(EpisodeBufferProtocol): """Unbounded asyncio.Queue-based buffer for Tinker backend. - No staleness filtering — that happens at consumption time. - No max queue size — quota controls growth externally via SyncCoordinator. + No staleness filtering — throttle controls growth externally via SyncCoordinator. Tinker's compute happens on remote servers, so the Python process only orchestrates — no threading needed. """ def __init__(self): - self._queue: asyncio.Queue[BufferedEpisode | None] = asyncio.Queue() # unbounded + self._queue: asyncio.Queue[BufferedEpisodeGroup | None] = asyncio.Queue() # unbounded self._generation_complete = False self._total_produced = 0 self._total_consumed = 0 - async def put(self, item: BufferedEpisode) -> None: + async def put(self, item: BufferedEpisodeGroup) -> None: await self._queue.put(item) self._total_produced += 1 - async def get(self) -> BufferedEpisode | None: + async def get(self) -> BufferedEpisodeGroup | None: while True: if self._generation_complete and self._queue.empty(): return None diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py index ceb9c603e..ca9ab0e9f 100644 --- a/rllm/experimental/rollout/rollout_engine.py +++ b/rllm/experimental/rollout/rollout_engine.py @@ -1,3 +1,4 @@ +import asyncio from dataclasses import dataclass from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput @@ -16,9 +17,11 @@ class ModelOutput: multi_modal_inputs: dict[str, list] | None = None logprobs: list[float] | None = None # completion logprobs prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) prompt_length: int = 0 completion_length: int = 0 finish_reason: str | None = None + weight_version: int | None = None # policy version at time of generation def to_dict(self): return { @@ -34,6 +37,7 @@ def to_dict(self): "prompt_length": self.prompt_length, "completion_length": self.completion_length, "finish_reason": self.finish_reason, + "weight_version": self.weight_version, } @classmethod @@ -51,6 +55,7 @@ def from_dict(cls, data: dict): prompt_length=data.get("prompt_length", 0), completion_length=data.get("completion_length", 0), finish_reason=data.get("finish_reason"), + weight_version=data.get("weight_version"), ) @@ -60,7 +65,42 @@ class RolloutEngine: is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks def __init__(self, *args, **kwargs): - pass + # Gate mechanism for pausing model calls during weight sync + self._gate: asyncio.Event = asyncio.Event() + self._gate.set() # open by default + self._active_calls: int = 0 + self._drained_event: asyncio.Event = asyncio.Event() + self._drained_event.set() # initially drained (no active calls) + self.weight_version: int = 0 + + # --- Gate mechanism --- + + def close_gate(self) -> None: + """Close the gate. New model calls will block at wait_for_gate().""" + self._gate.clear() + + def open_gate(self) -> None: + """Open the gate, releasing any blocked model calls.""" + self._gate.set() + + async def wait_for_gate(self) -> None: + """Wait until gate is open, then register as active call. + Engines must call this at the START of get_model_response().""" + await self._gate.wait() + self._active_calls += 1 + self._drained_event.clear() + + def on_model_call_complete(self) -> None: + """Unregister active call. Engines must call this at the END of + get_model_response() (in a finally block).""" + self._active_calls -= 1 + if self._active_calls <= 0: + self._active_calls = 0 + self._drained_event.set() + + async def wait_for_drain(self) -> None: + """Wait until all active model calls complete. Used during weight sync.""" + await self._drained_event.wait() async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: raise NotImplementedError("get_model_response is not implemented") diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py index 7f60dc6fc..cbef5fd88 100644 --- a/rllm/experimental/sync_coordinator.py +++ b/rllm/experimental/sync_coordinator.py @@ -8,106 +8,99 @@ @dataclass class SyncCoordinatorConfig: - """Configuration derived from trainer config and async config for the coordinator to use.""" - - train_batch_size: int # from config.data.train_batch_size - group_size: int # from config.rllm.rollout.n - staleness_threshold: float # from async config - trigger_parameter_sync_step: int # from async config - requeue_stale: bool # from async config - num_minibatches: int # from async config - - @property - def episodes_per_train_step(self) -> int: - """Number of episodes per training step.""" - return self.train_batch_size * self.group_size + mini_batch_size: int # episode groups per optimizer step + group_size: int # episodes per group (rollout.n) + staleness_threshold: float + trigger_parameter_sync_step: int @property def max_rollout_quota(self) -> int: - """Max episodes in-flight + in-queue between syncs.""" - return int((1 + self.staleness_threshold) * self.trigger_parameter_sync_step * self.episodes_per_train_step) + """Max outstanding groups (dispatched but not yet consumed by training).""" + return int((1 + self.staleness_threshold) * self.trigger_parameter_sync_step * self.mini_batch_size) class SyncCoordinator: - """Coordinates rollout scheduling and parameter sync between generation and training loops. - - Core responsibility: control how many episodes can be in-flight + queued, - and when to trigger weight synchronization. - """ + """Coordinates rollout scheduling and parameter sync between generation and training loops.""" def __init__(self, config: SyncCoordinatorConfig): self.config = config - # State self._policy_version: int = 0 - self._scheduled_count: int = 0 # in-flight episodes not yet in buffer - self._stale_requeue_count: int = 0 # stale tasks to re-add to quota - self._steps_since_sync: int = 0 # training steps since last sync - self._total_stale_discarded: int = 0 - - # Events - self.sync_complete_event: asyncio.Event = asyncio.Event() - self.sync_complete_event.set() # initially unblocked + self._outstanding: int = 0 # groups dispatched but not yet consumed by training + self._steps_since_sync: int = 0 + self._total_syncs: int = 0 + self._total_groups_filtered: int = 0 + + # Throttle — blocks generation when outstanding >= max_rollout_quota + self._throttle_event: asyncio.Event = asyncio.Event() + self._throttle_event.set() + + # Generation pause — blocks generation during validation or weight sync + self._generation_paused: asyncio.Event = asyncio.Event() + self._generation_paused.set() + self.generation_done: bool = False @property def policy_version(self) -> int: return self._policy_version - def compute_new_schedule_count(self, remain_in_queue: int) -> int: - """Compute how many new episodes the generation loop should schedule. + # --- Throttle --- - Formula: new = max_rollout_quota - remain_in_queue - scheduled + requeue_stale - """ - available = self.config.max_rollout_quota - remain_in_queue - self._scheduled_count + self._stale_requeue_count - self._stale_requeue_count = 0 # consumed - return max(0, available) + def on_group_dispatched(self) -> None: + """Generation loop dispatched one prompt (n rollouts).""" + self._outstanding += 1 + if self._outstanding >= self.config.max_rollout_quota: + self._throttle_event.clear() - def on_episodes_scheduled(self, count: int) -> None: - """Called when the generation loop dispatches tasks.""" - self._scheduled_count += count + def on_group_consumed(self) -> None: + """Training loop consumed one group from the buffer.""" + self._outstanding = max(0, self._outstanding - 1) + self._throttle_event.set() - def on_episode_generated(self, count: int = 1) -> None: - """Called when an episode arrives in the buffer (no longer in-flight).""" - self._scheduled_count = max(0, self._scheduled_count - count) + def on_group_filtered(self) -> None: + """Accumulator filtered out a uniform group. Frees throttle slot and tracks count.""" + self._total_groups_filtered += 1 + self.on_group_consumed() + + async def wait_for_throttle(self) -> None: + """Generation loop blocks here when quota is full.""" + await self._throttle_event.wait() + + def has_quota(self) -> bool: + """Whether the generation loop can dispatch another group.""" + return self._outstanding < self.config.max_rollout_quota + + # --- Weight sync --- def on_training_step_complete(self) -> None: - """Called after a gradient update.""" self._steps_since_sync += 1 def should_sync(self) -> bool: - """Whether it's time to synchronize parameters.""" return self._steps_since_sync >= self.config.trigger_parameter_sync_step def on_sync_complete(self) -> None: - """Called after weight sync. Bumps policy version, resets counters, signals gen loop.""" self._policy_version += 1 self._steps_since_sync = 0 - self.sync_complete_event.set() - - def on_stale_discarded(self, count: int, requeue: bool) -> None: - """Called when the training loop discards stale episodes.""" - self._total_stale_discarded += count - if requeue: - self._stale_requeue_count += count - - def is_episode_stale(self, ep_version: int) -> bool: - """Check if an episode is too stale to use. - - An episode is stale if the version gap exceeds: - staleness_threshold * trigger_parameter_sync_step - """ - if self.config.staleness_threshold == 0.0: - # On-policy: only current version is acceptable - return ep_version < self._policy_version - max_gap = self.config.staleness_threshold * self.config.trigger_parameter_sync_step - return (self._policy_version - ep_version) > max_gap + self._total_syncs += 1 + + # --- Generation pause (for validation / non-partial weight sync) --- + + def pause_generation(self) -> None: + self._generation_paused.clear() + + def resume_generation(self) -> None: + self._generation_paused.set() + + async def wait_for_generation_allowed(self) -> None: + await self._generation_paused.wait() def stats(self) -> dict: return { "async/policy_version": self._policy_version, - "async/scheduled_count": self._scheduled_count, + "async/outstanding_groups": self._outstanding, "async/steps_since_sync": self._steps_since_sync, - "async/total_stale_discarded": self._total_stale_discarded, "async/max_rollout_quota": self.config.max_rollout_quota, + "async/total_syncs": self._total_syncs, + "async/total_groups_filtered": self._total_groups_filtered, } diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index d8d48bd2e..d9aa54848 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -1,5 +1,6 @@ import asyncio import time +import uuid from abc import ABC, abstractmethod from collections import Counter, defaultdict from collections.abc import Callable, Iterable @@ -28,10 +29,13 @@ RejectionSamplingState, apply_rejection_sampling_and_filtering, ) -from rllm.experimental.common.transform import _default_traj_grouping_hook, transform_episodes_to_trajectory_groups +from rllm.experimental.common.transform import ( + _default_traj_grouping_hook, + transform_episodes_to_trajectory_groups, +) from rllm.experimental.common.visualization import visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine -from rllm.experimental.episode_buffer import AsyncioEpisodeBuffer, BufferedEpisode +from rllm.experimental.episode_buffer import AsyncioEpisodeBuffer, BufferedEpisodeGroup, EpisodeGroupAccumulator from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine from rllm.experimental.sync_coordinator import SyncCoordinator, SyncCoordinatorConfig @@ -131,11 +135,11 @@ def __init__( async_cfg = self.rllm_config.get("async_training", {}) self.async_config = AsyncTrainingConfig( enabled=async_cfg.get("enabled", False), + mini_batch_size=async_cfg.get("mini_batch_size", 1), + streaming_chunks=async_cfg.get("streaming_chunks", 1), staleness_threshold=async_cfg.get("staleness_threshold", 0.0), trigger_parameter_sync_step=async_cfg.get("trigger_parameter_sync_step", 1), partial_rollout=async_cfg.get("partial_rollout", True), - num_minibatches=async_cfg.get("num_minibatches", 1), - requeue_stale=async_cfg.get("requeue_stale", True), ) rollout_engine: RolloutEngine = self.backend.init_rollout_engine( @@ -186,6 +190,7 @@ def _validate_and_setup_configs(self): mode=rs_mode, min_partial_solve_tasks=self.rllm_config.rejection_sample.min_partial_solve_tasks, min_trajs_per_group=self.rllm_config.rejection_sample.min_trajs_per_group, + filter_uniform_groups=self.rllm_config.rejection_sample.get("filter_uniform_groups", False), ) # algorithm config (used for rLLM-native advantage computation) @@ -381,14 +386,13 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N # ========================================================================= async def _fit_fully_async(self, trainer_state: TrainerState) -> None: - """Fully-async generation + training with episode-level streaming.""" + """Fully-async generation + training with group-level streaming.""" + assert self.config.data.train_batch_size == 1, f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" coord_config = SyncCoordinatorConfig( - train_batch_size=self.config.data.train_batch_size, + mini_batch_size=self.async_config.mini_batch_size, group_size=self.rllm_config.rollout.n, staleness_threshold=self.async_config.staleness_threshold, trigger_parameter_sync_step=self.async_config.trigger_parameter_sync_step, - requeue_stale=self.async_config.requeue_stale, - num_minibatches=self.async_config.num_minibatches, ) coordinator = SyncCoordinator(coord_config) buffer = AsyncioEpisodeBuffer() @@ -407,147 +411,174 @@ async def _fit_fully_async(self, trainer_state: TrainerState) -> None: ) async def _generation_loop(self, trainer_state: TrainerState, buffer: AsyncioEpisodeBuffer, coordinator: SyncCoordinator) -> None: - """Generate episodes and stream to buffer. Quota-controlled by SyncCoordinator.""" - try: - group_size = self.rllm_config.rollout.n - use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + """Generate episodes and stream to buffer. Continuous fire-and-forget per prompt.""" + group_size = self.rllm_config.rollout.n + accumulator = EpisodeGroupAccumulator( + group_size=group_size, + buffer=buffer, + filter_uniform_groups=self.rs_config.filter_uniform_groups, + on_group_filtered=coordinator.on_group_filtered, + ) + try: for epoch in range(self.rllm_config.trainer.total_epochs): train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) + self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) - # Flatten to individual tasks - all_tasks = [] for batch in train_dataloader: - all_tasks.extend(batch) - - task_idx = 0 - while task_idx < len(all_tasks): - # Check total_batches limit - if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: - return - - # 1. Compute quota - quota = coordinator.compute_new_schedule_count(buffer.qsize()) - if quota <= 0: - coordinator.sync_complete_event.clear() - await coordinator.sync_complete_event.wait() - continue - - # 2. Pick tasks up to quota (convert episodes to prompt count) - prompts_to_schedule = min(quota // group_size, len(all_tasks) - task_idx) - if prompts_to_schedule <= 0: - await asyncio.sleep(0.1) - continue - chunk = all_tasks[task_idx : task_idx + prompts_to_schedule] - task_idx += prompts_to_schedule - - total_episodes = prompts_to_schedule * group_size - coordinator.on_episodes_scheduled(total_episodes) - gen_policy_version = coordinator.policy_version - - # Set training step metadata - self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) - - # 3. Stream generation — episodes pushed to episode_queue as they complete - episode_queue: asyncio.Queue = asyncio.Queue() - gen_task = asyncio.create_task(self.backend.generate_episodes_streaming(chunk, self.agent_workflow_engine, episode_queue, is_validation=False)) - - # 4. Drain episodes into buffer as they arrive - received = 0 - while received < total_episodes: - try: - task_id, rollout_idx, _, episode = await asyncio.wait_for(episode_queue.get(), timeout=300.0) - except asyncio.TimeoutError: - break - if episode is not None and episode.trajectories: - await buffer.put( - BufferedEpisode( - episode=episode, - policy_version=gen_policy_version, - task=episode.task, - task_id=episode.task_id, - ) - ) - coordinator.on_episode_generated(1) - received += 1 - - await gen_task # ensure streaming task completes + # async training uses train_batch_size=1 + task = batch[0] + + # Block during validation / non-partial sync + await coordinator.wait_for_generation_allowed() + + # Dispatch-time throttle: block if quota exhausted + if not coordinator.has_quota(): + await coordinator.wait_for_throttle() + + coordinator.on_group_dispatched() + + # Generate a unique task_id for this prompt + task_id = str(uuid.uuid4()) + + # Fire-and-forget n rollout tasks for this prompt + for rollout_idx in range(group_size): + + async def _run_rollout(t=task, tid=task_id, ridx=rollout_idx): + try: + _, _, _, episode = await self.agent_workflow_engine.process_task_with_retry(task=t, task_id=tid, rollout_idx=ridx, result_idx=0) + await accumulator.add_episode(tid, episode) + except Exception: + # Group can never complete — free the throttle slot to prevent deadlock + coordinator.on_group_consumed() + raise + + asyncio.create_task(_run_rollout()) + + # Wait for all in-flight rollouts to finish before marking generation complete + await self._wait_for_all_workflows_idle() finally: coordinator.generation_done = True buffer.mark_generation_complete() async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpisodeBuffer, coordinator: SyncCoordinator) -> None: - """Collect episodes from buffer, group, train, and sync. Runs concurrently with generation.""" - episodes_per_step = coordinator.config.episodes_per_train_step + """Collect episode groups from buffer, train with streaming grad accumulation. Runs concurrently with generation.""" + mini_batch_size = self.async_config.mini_batch_size + streaming_chunks = self.async_config.streaming_chunks + groups_per_chunk = mini_batch_size // streaming_chunks use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 + rollout_engine = self.agent_workflow_engine.rollout_engine while True: - # 1. Collect episodes for one training step - collected: list[BufferedEpisode] = [] - while len(collected) < episodes_per_step: - item = await buffer.get() - if item is None: - break # generation done + buffer empty - if coordinator.is_episode_stale(item.policy_version): - coordinator.on_stale_discarded(1, requeue=coordinator.config.requeue_stale) + trainer_state.reset_batch() + step_start = time.perf_counter() + all_collected: list[BufferedEpisodeGroup] = [] + all_episodes: list[Episode] = [] + buffer_wait_time = 0.0 + + # 1. Streaming gradient accumulation across chunks + for chunk_idx in range(streaming_chunks): + # Pull groups_per_chunk groups from buffer + chunk_groups: list[BufferedEpisodeGroup] = [] + while len(chunk_groups) < groups_per_chunk: + t0 = time.perf_counter() + item = await buffer.get() + buffer_wait_time += time.perf_counter() - t0 + if item is None: + break # generation done + buffer empty + chunk_groups.append(item) + + if not chunk_groups: + break + + for _ in chunk_groups: + coordinator.on_group_consumed() + all_collected.extend(chunk_groups) + + # Flatten episodes from groups + episodes = [] + for group in chunk_groups: + episodes.extend(group.episodes) + all_episodes.extend(episodes) + + # Transform → rejection sampling → backend pipeline + trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups( + episodes, + self.transform_config, + self.cf_config, + traj_grouping_hook=self.traj_grouping_hook, + ) + trainer_state.trajectory_groups = trajectory_groups + trainer_state.episodes = episodes + trainer_state.metrics.update(transform_metrics) + + filtered_groups, filtered_episodes, rs_metrics = apply_rejection_sampling_and_filtering( + episodes, + trajectory_groups, + self.rs_config, + RejectionSamplingState(), + ) + trainer_state.metrics.update(rs_metrics) + trainer_state.trajectory_groups = filtered_groups + trainer_state.episodes = filtered_episodes + if not trainer_state.has_trajectory_groups: continue - collected.append(item) - if not collected: + await self.backend.on_batch_start(trainer_state) + trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) + await self.backend.process_backend_batch(trainer_state) + await self.backend.compute_advantages(trainer_state, self.algorithm_config) + + if not all_collected: if coordinator.generation_done and buffer.qsize() == 0: break continue - # 2. Extract episodes, transform to trajectory groups (at consumption time) - episodes = [be.episode for be in collected] - trainer_state.reset_batch() - trainer_state.episodes = episodes - trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups( - episodes, - self.transform_config, - self.cf_config, - traj_grouping_hook=self.traj_grouping_hook, - ) - trainer_state.trajectory_groups = trajectory_groups - trainer_state.metrics.update(transform_metrics) - - # 3. Rejection sampling - filtered_groups, filtered_episodes, rs_metrics = apply_rejection_sampling_and_filtering( - episodes, - trajectory_groups, - self.rs_config, - RejectionSamplingState(), - ) - trainer_state.metrics.update(rs_metrics) - trainer_state.trajectory_groups = filtered_groups - trainer_state.episodes = filtered_episodes - if not trainer_state.has_trajectory_groups: - continue - - # 4. Stages 4-7: backend training pipeline - await self.backend.on_batch_start(trainer_state) - trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) - await self.backend.process_backend_batch(trainer_state) - await self.backend.compute_advantages(trainer_state, self.algorithm_config) + # 2. Single optimizer step + trainer_state.episodes = all_episodes await self.backend.update_policy(trainer_state) - # 5. Training step done — check sync + # 3. Training step done — check sync coordinator.on_training_step_complete() + sync_time = 0.0 if coordinator.should_sync(): - trainer_state.policy_version = coordinator.policy_version + 1 - await self.backend.on_policy_updated(trainer_state) - coordinator.on_sync_complete() - - # 6. Metrics, logging, visualization - workflow_metrics, termination_counts = self._collect_workflow_metrics_from_episodes(trainer_state.episodes) - - # Compute average staleness of consumed episodes - avg_staleness = np.mean([coordinator.policy_version - be.policy_version for be in collected]) - trainer_state.metrics["async/avg_episode_staleness"] = avg_staleness - trainer_state.metrics.update(buffer.stats()) + t0 = time.perf_counter() + await self._perform_weight_sync(trainer_state, coordinator, rollout_engine) + sync_time = time.perf_counter() - t0 + + # 4. Metrics, logging, visualization + workflow_metrics, termination_counts = self._collect_workflow_metrics_from_episodes(all_episodes) + + staleness_values = [coordinator.policy_version - g.weight_version for g in all_collected] + trainer_state.metrics["async/staleness_mean"] = np.mean(staleness_values) + trainer_state.metrics["async/staleness_min"] = np.min(staleness_values) + trainer_state.metrics["async/staleness_max"] = np.max(staleness_values) + trainer_state.metrics["async/groups_consumed"] = len(all_collected) + + # Timing + trainer_state.metrics["time/step"] = time.perf_counter() - step_start + trainer_state.metrics["time/buffer_wait"] = buffer_wait_time + if sync_time > 0: + trainer_state.metrics["time/weight_sync"] = sync_time + + # Weight version delta within trajectories (meaningful in partial_rollout mode) + traj_deltas = [] + for ep in all_episodes: + for traj in ep.trajectories: + versions = [s.weight_version for s in traj.steps if s.weight_version is not None] + if len(versions) >= 2: + traj_deltas.append(max(versions) - min(versions)) + if traj_deltas: + trainer_state.metrics["async/traj_weight_delta_mean"] = np.mean(traj_deltas) + trainer_state.metrics["async/traj_weight_delta_min"] = np.min(traj_deltas) + trainer_state.metrics["async/traj_weight_delta_max"] = np.max(traj_deltas) + + buffer_stats = buffer.stats() + trainer_state.metrics["async/gen_train_ratio"] = buffer_stats["async/total_produced"] / max(trainer_state.global_step, 1) + trainer_state.metrics.update(buffer_stats) trainer_state.metrics.update(coordinator.stats()) - if self.tokenizer is not None: + if self.tokenizer is not None and trainer_state.has_trajectory_groups: visualize_trajectory_last_steps( trainer_state.trajectory_groups, tokenizer=self.tokenizer, @@ -567,13 +598,13 @@ async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpiso self.logger.log( data=trainer_state.metrics, step=trainer_state.global_step, - episodes=trainer_state.episodes, + episodes=all_episodes, trajectory_groups=trainer_state.trajectory_groups, ) # Periodic validation if self.rllm_config.trainer.test_freq > 0 and trainer_state.global_step % self.rllm_config.trainer.test_freq == 0: - await self._validate_async(trainer_state) + await self._validate_async_with_pause(trainer_state, coordinator) trainer_state.global_step += 1 @@ -581,6 +612,49 @@ async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpiso if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: break + async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: SyncCoordinator, rollout_engine: RolloutEngine) -> None: + """Synchronize weights between training and rollout engines. + + Two modes depending on partial_rollout: + - partial_rollout=True: Uses rollout engine gate (model-call level). + Workflows block between turns, resume with new weights. + - partial_rollout=False: Uses coordinator generation pause (dispatch level). + Workflows finish naturally, gate stays open. + """ + if self.async_config.partial_rollout: + # Block new model calls; in-flight calls finish, workflows pause between turns + rollout_engine.close_gate() + await rollout_engine.wait_for_drain() + else: + # Stop dispatching new prompts, let all workflows finish naturally + coordinator.pause_generation() + await self._wait_for_all_workflows_idle() + + trainer_state.policy_version = coordinator.policy_version + 1 + await self.backend.on_policy_updated(trainer_state) + rollout_engine.weight_version = trainer_state.policy_version + coordinator.on_sync_complete() + + if self.async_config.partial_rollout: + rollout_engine.open_gate() + else: + coordinator.resume_generation() + + async def _wait_for_all_workflows_idle(self) -> None: + """Wait for all n_parallel_tasks workflows to return to the pool.""" + pool = self.agent_workflow_engine + while pool.workflow_queue.qsize() < pool.n_parallel_tasks: + await asyncio.sleep(0.1) + + async def _validate_async_with_pause(self, trainer_state: TrainerState, coordinator: SyncCoordinator) -> dict: + """Validation with dispatch-level pause. Waits for workflows to drain, then runs validation.""" + coordinator.pause_generation() + await self._wait_for_all_workflows_idle() + try: + return await self._validate_async(trainer_state) + finally: + coordinator.resume_generation() + async def _validate_async(self, trainer_state: TrainerState) -> dict: """Validate the model (async implementation).""" n_val_samples = self.rllm_config.rollout.n_val @@ -724,7 +798,7 @@ def __init__( train_dataset: Dataset | None = None, val_dataset: Dataset | None = None, workflow_args: dict | None = None, - backend: Literal["verl", "tinker"] = "verl", + backend: Literal["verl", "tinker", "fireworks"] = "verl", **kwargs, ): if backend == "verl": @@ -749,6 +823,19 @@ def __init__( workflow_args=workflow_args, **kwargs, ) + elif backend == "fireworks": + from rllm.trainer.fireworks.fireworks_launcher import ( + FireworksTrainerLauncher, + ) + + self.launcher = FireworksTrainerLauncher( + config=config, + workflow_class=workflow_class, + train_dataset=train_dataset, + val_dataset=val_dataset, + workflow_args=workflow_args, + **kwargs, + ) def train(self): self.launcher.train() From f9f01e584d5a92c5d6a6216c924177f2ce0f6a99 Mon Sep 17 00:00:00 2001 From: listar2000 <35262801+listar2000@users.noreply.github.com> Date: Tue, 10 Mar 2026 15:21:52 -0500 Subject: [PATCH 04/21] Refactor chat parser and migrate experimental rollout to engine (#435) * start refactoring * revert chat template parser and override tinker parser test * revert and fix chat parser test * refactor tinker engine to use tinker parser * deprecate bypass renderer mentions * move experimental rollout out --- docs/experimental/rllm-and-backend-config.md | 1 - .../train_countdown_distill_tinker.sh | 1 - .../opsd/train_deepmath_distill_tinker.sh | 1 - .../train_deepmath_distill_tinker.py | 1 - .../train_deepmath_distill_tinker.sh | 1 - rllm/engine/agent_sdk_engine.py | 4 +- rllm/engine/agent_workflow_engine.py | 4 +- rllm/engine/rollout/__init__.py | 23 +- .../rollout/completer.py | 6 +- rllm/engine/rollout/rollout_engine.py | 23 +- rllm/engine/rollout/tinker_engine.py | 387 ++++++++--------- .../{experimental => engine}/rollout/types.py | 3 +- rllm/engine/rollout/verl_engine.py | 66 ++- .../config/rllm/backend/tinker.yaml | 5 +- .../engine/unified_workflow_engine.py | 4 +- rllm/experimental/protocol.py | 2 +- rllm/experimental/rollout/__init__.py | 31 +- rllm/experimental/rollout/rollout_engine.py | 87 ---- rllm/experimental/rollout/tinker_engine.py | 350 --------------- rllm/experimental/rollout/verl_engine.py | 138 ------ .../test_examples/opsd/math_opsd_workflow.py | 4 +- rllm/experimental/unified_trainer.py | 2 +- rllm/experimental/verl/verl_backend.py | 6 +- rllm/parser/__init__.py | 9 + rllm/parser/tinker_parser.py | 400 ++++++++++++++++++ rllm/parser/utils.py | 11 + rllm/trainer/config/tinker_rl_trainer.yaml | 1 - rllm/trainer/tinker/tinker_backend.py | 4 +- rllm/trainer/tinker/transform.py | 4 +- tests/parser/conftest.py | 54 +++ tests/parser/test_chat_parser.py | 2 +- tests/parser/test_tinker_parser.py | 224 ++++++++++ 32 files changed, 997 insertions(+), 862 deletions(-) rename rllm/{experimental => engine}/rollout/completer.py (96%) rename rllm/{experimental => engine}/rollout/types.py (92%) delete mode 100644 rllm/experimental/rollout/rollout_engine.py delete mode 100644 rllm/experimental/rollout/tinker_engine.py delete mode 100644 rllm/experimental/rollout/verl_engine.py create mode 100644 rllm/parser/tinker_parser.py create mode 100644 tests/parser/conftest.py create mode 100644 tests/parser/test_tinker_parser.py diff --git a/docs/experimental/rllm-and-backend-config.md b/docs/experimental/rllm-and-backend-config.md index 4baebd341..2f9fc8de5 100644 --- a/docs/experimental/rllm-and-backend-config.md +++ b/docs/experimental/rllm-and-backend-config.md @@ -238,7 +238,6 @@ This file contains: | `rollout_engine.reasoning_effort` | `str` | `medium` | Reasoning effort mode | | `rollout_engine.accumulate_reasoning` | `bool` | `false` | Whether to accumulate reasoning across steps | | `rollout_engine.disable_thinking` | `bool` | `false` | Whether to disable thinking tokens | -| `rollout_engine.bypass_render_with_parser` | `bool` | `false` | Whether to bypass render parsing | | `rollout_engine.renderer_name` | `str | null` | `null` | Optional renderer name | | `data.max_prompt_length` | `int` | `2048` | Max prompt length | | `data.max_response_length` | `int` | `2048` | Max response length | diff --git a/examples/countdown/train_countdown_distill_tinker.sh b/examples/countdown/train_countdown_distill_tinker.sh index 7b3a17d5f..1107a312d 100644 --- a/examples/countdown/train_countdown_distill_tinker.sh +++ b/examples/countdown/train_countdown_distill_tinker.sh @@ -24,4 +24,3 @@ python -m examples.countdown.train_countdown_tinker \ trainer.test_freq=10 \ trainer.save_freq=1000 \ trainer.default_local_dir='./outputs/countdown-distill-tinker-8b' \ - rollout_engine.bypass_render_with_parser=True diff --git a/examples/math_distill/opsd/train_deepmath_distill_tinker.sh b/examples/math_distill/opsd/train_deepmath_distill_tinker.sh index cf3a8492d..43c2c74ed 100644 --- a/examples/math_distill/opsd/train_deepmath_distill_tinker.sh +++ b/examples/math_distill/opsd/train_deepmath_distill_tinker.sh @@ -25,5 +25,4 @@ python -m examples.math_distill.opsd.train_deepmath_distill_tinker \ training.default_local_dir='./outputs/opsd-deepmath-8b-rllm' \ rllm.algorithm.use_precomputed_advantage=true \ rllm.algorithm.loss_fn=importance_sampling \ - rollout_engine.bypass_render_with_parser=True \ rllm.workflow.n_parallel_tasks=512 diff --git a/examples/math_distill/train_deepmath_distill_tinker.py b/examples/math_distill/train_deepmath_distill_tinker.py index d4dc5f343..fb2628721 100644 --- a/examples/math_distill/train_deepmath_distill_tinker.py +++ b/examples/math_distill/train_deepmath_distill_tinker.py @@ -26,7 +26,6 @@ def main(config: DictConfig): tokenizer=teacher_tokenizer, service_client=teacher_service_client, sampling_client=teacher_sampling_client, - bypass_render_with_parser=True, ) trainer = AgentTrainer( diff --git a/examples/math_distill/train_deepmath_distill_tinker.sh b/examples/math_distill/train_deepmath_distill_tinker.sh index 26efe10dc..69a769592 100644 --- a/examples/math_distill/train_deepmath_distill_tinker.sh +++ b/examples/math_distill/train_deepmath_distill_tinker.sh @@ -25,6 +25,5 @@ python -m examples.math_distill.train_deepmath_distill_tinker \ training.default_local_dir='./outputs/deepmath-distill-8b-32b-unified' \ rllm.algorithm.use_precomputed_advantage=true \ rllm.algorithm.loss_fn=importance_sampling \ - rollout_engine.bypass_render_with_parser=False \ rollout_engine.renderer_name=qwen3 \ rllm.workflow.n_parallel_tasks=512 diff --git a/rllm/engine/agent_sdk_engine.py b/rllm/engine/agent_sdk_engine.py index 393829314..ed10a30bc 100644 --- a/rllm/engine/agent_sdk_engine.py +++ b/rllm/engine/agent_sdk_engine.py @@ -444,11 +444,11 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": self.rollout_engine.wake_up() if batch.meta_info.get("validate", False): - self.rollout_engine.validate = True + self.rollout_engine.is_validation = True tasks = batch.non_tensor_batch["extra_info"].tolist() task_ids = batch.non_tensor_batch["task_ids"].tolist() episodes = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes - self.rollout_engine.validate = False + self.rollout_engine.is_validation = False if isinstance(self.rollout_engine, VerlEngine): await self.rollout_engine.sleep() diff --git a/rllm/engine/agent_workflow_engine.py b/rllm/engine/agent_workflow_engine.py index f2ea8f6b0..3bea843e7 100644 --- a/rllm/engine/agent_workflow_engine.py +++ b/rllm/engine/agent_workflow_engine.py @@ -208,14 +208,14 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": is_validation = batch.meta_info.get("validate", False) if is_validation: - self.rollout_engine.validate = True + self.rollout_engine.is_validation = True self.current_mode = "val" else: self.current_mode = "train" tasks = batch.non_tensor_batch["extra_info"].tolist() task_ids = batch.non_tensor_batch["task_ids"].tolist() results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes - self.rollout_engine.validate = False + self.rollout_engine.is_validation = False await self.rollout_engine.sleep() diff --git a/rllm/engine/rollout/__init__.py b/rllm/engine/rollout/__init__.py index 47995ca85..471682f61 100644 --- a/rllm/engine/rollout/__init__.py +++ b/rllm/engine/rollout/__init__.py @@ -1,11 +1,26 @@ -# Avoid importing concrete engines at module import time to prevent circular imports +from typing import TYPE_CHECKING + from .rollout_engine import ModelOutput, RolloutEngine +from .types import TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput, VerlTokenInput, VerlTokenOutput + +if TYPE_CHECKING: + from .tinker_engine import TinkerEngine + from .verl_engine import VerlEngine __all__ = [ "ModelOutput", "RolloutEngine", "OpenAIEngine", + "TinkerEngine", "VerlEngine", + # Token types + "TokenInput", + "TokenOutput", + "TinkerTokenInput", + "TinkerTokenOutput", + "VerlTokenInput", + "VerlTokenOutput", + "Tokenizer", ] @@ -14,6 +29,10 @@ def __getattr__(name): from .openai_engine import OpenAIEngine as _OpenAIEngine return _OpenAIEngine + if name == "TinkerEngine": + from .tinker_engine import TinkerEngine as _TinkerEngine + + return _TinkerEngine if name == "VerlEngine": try: from .verl_engine import VerlEngine as _VerlEngine @@ -21,4 +40,4 @@ def __getattr__(name): return _VerlEngine except Exception: raise AttributeError(name) from None - raise AttributeError(name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/rllm/experimental/rollout/completer.py b/rllm/engine/rollout/completer.py similarity index 96% rename from rllm/experimental/rollout/completer.py rename to rllm/engine/rollout/completer.py index 4890e0a62..0aab94471 100644 --- a/rllm/experimental/rollout/completer.py +++ b/rllm/engine/rollout/completer.py @@ -12,8 +12,8 @@ from typing import Any from rllm.agents.agent import Step -from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput +from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine +from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput from rllm.parser import ChatTemplateParser @@ -84,7 +84,7 @@ def __init__(self, rollout_engine: RolloutEngine): raise ValueError(f"The rollout engine {cls_name} does not support token-in-token-out") # we also require the rollout engine has a chat parser and a tokenizer if rollout_engine.chat_parser is None or rollout_engine.tokenizer is None: - raise ValueError("The rollout engine must have a chat parser and a tokenizer. For Tinker engine, make sure you have set bypass_render_with_parser=True.") + raise ValueError("The rollout engine must have a chat parser and a tokenizer.") self.tokenizer = rollout_engine.tokenizer self.chat_parser = rollout_engine.chat_parser diff --git a/rllm/engine/rollout/rollout_engine.py b/rllm/engine/rollout/rollout_engine.py index 7f3895429..74ccd8b73 100644 --- a/rllm/engine/rollout/rollout_engine.py +++ b/rllm/engine/rollout/rollout_engine.py @@ -1,5 +1,7 @@ from dataclasses import dataclass +from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput +from rllm.parser import ChatTemplateParser from rllm.tools.tool_base import ToolCall @@ -9,7 +11,7 @@ class ModelOutput: content: str | None = None reasoning: str | None = None tool_calls: list[ToolCall] | None = None - prompt_ids: list[int] | None = None + prompt_ids: TokenInput | None = None completion_ids: list[int] | None = None multi_modal_inputs: dict[str, list] | None = None logprobs: list[float] | None = None # completion logprobs @@ -53,12 +55,31 @@ def from_dict(cls, data: dict): class RolloutEngine: + chat_parser: ChatTemplateParser | None = None + tokenizer: Tokenizer | None = None + is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks + def __init__(self, *args, **kwargs): pass async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: raise NotImplementedError("get_model_response is not implemented") + def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: + """ + Assemble model output from a token output. + """ + raise NotImplementedError("assemble_model_output is not implemented") + + async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TokenOutput: + """Obtain the token output from the given token input.""" + raise NotImplementedError("get_token_output_from_token_input is not implemented") + + @property + def supports_token_in_token_out(self) -> bool: + """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" + return False + async def wake_up(self): pass diff --git a/rllm/engine/rollout/tinker_engine.py b/rllm/engine/rollout/tinker_engine.py index c6e35e211..12de041e2 100644 --- a/rllm/engine/rollout/tinker_engine.py +++ b/rllm/engine/rollout/tinker_engine.py @@ -1,37 +1,76 @@ -import json +from typing import cast import tinker from tinker.types import ModelInput from tinker_cookbook import model_info, renderers +from typing_extensions import override # need to use typing_extensions for python < 3.12 from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.parser import ChatTemplateParser -from rllm.tools.tool_base import ToolCall +from rllm.engine.rollout.types import ImageProcessor, TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput +from rllm.parser.tinker_parser import TinkerChatTemplateParser from rllm.workflows import TerminationEvent, TerminationReason +""" +Utility functions for Tinker engine. Partly borrowed from +https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py +""" + + +def _flat_token_input_to_model_input(token_input: TinkerTokenInput) -> ModelInput: + """Convert a flat token input to a ModelInput.""" + if not token_input: # empty list + return ModelInput(chunks=[]) + + out: list[tinker.ModelInputChunk] = [] + current_text_chunk: list[int] = [] + + def flush_text_chunk(): + if current_text_chunk: + out.append(tinker.EncodedTextChunk(tokens=current_text_chunk)) + current_text_chunk.clear() + + for elem in token_input: + if isinstance(elem, int): + current_text_chunk.append(elem) + else: + flush_text_chunk() + out.append(elem) + + flush_text_chunk() # final clear up + return tinker.ModelInput(chunks=out) + + +def _flat_token_input_length(token_input: TokenInput) -> int: + """Get the length of a flat token input. This nicely handles both text and image inputs""" + length = 0 + for elem in token_input: + if isinstance(elem, int): + length += 1 + else: + length += elem.length + return length + class TinkerEngine(RolloutEngine): """ RolloutEngine implementation using Tinker for model inference. + + Wraps the tinker renderer with a TinkerChatTemplateParser, which provides + unified prompt building (including tool spec injection) and response parsing + (content, reasoning, tool_calls). """ def __init__( self, model_name: str, - tokenizer, + tokenizer: Tokenizer, service_client: tinker.ServiceClient, - sampling_client: tinker.SamplingClient = None, + base_url: str | None = None, max_prompt_length: int = 4096, max_response_length: int = 4096, max_model_length: int = 32768, sampling_params: dict | None = None, - val_sampling_params: dict | None = None, - bypass_render_with_parser: bool = False, - processor=None, - image_processor=None, - disable_thinking: bool = False, - accumulate_reasoning: bool = False, - reasoning_effort: str = "medium", + image_processor: ImageProcessor | None = None, renderer_name: str | None = None, **kwargs, ): @@ -42,55 +81,42 @@ def __init__( model_name: Name of the model to use tokenizer: Tokenizer for encoding/decoding service_client: Tinker ServiceClient instance - sampling_client: Tinker SamplingClient instance + base_url: Tinker service URL (default = null for local) max_prompt_length: Maximum prompt length in tokens max_response_length: Maximum response length in tokens max_model_length: Maximum total length (prompt + response) in tokens - sampling_params: Default sampling parameters for training (temperature, top_p, etc.) - val_sampling_params: Sampling parameters for validation (defaults to sampling_params if not provided) - bypass_render_with_parser: If True, use ChatTemplateParser instead of Tinker's renderer - processor: Optional processor for multimodal models (used when bypass_render_with_parser=True) + sampling_params: Default sampling parameters (temperature, top_p, etc.) image_processor: Optional image processor for vision-language models (used with renderer) - disable_thinking: Whether to disable thinking in generation prompt (used when bypass_render_with_parser=True) - accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) - renderer_name: Override renderer name (None = auto-detect from model) + renderer_name: Optional renderer name to use (None = auto-detect from model) + kwargs: Additional keyword arguments + - strip_thinking_from_history: Whether to strip thinking from history (only for Qwen3Renderer) """ + self.base_url = base_url self.model_name = model_name self.max_prompt_length = max_prompt_length self.max_response_length = max_response_length - self.max_model_length = max_model_length - 1 # Reserve 1 token for logprob computation + self.max_model_length = max_model_length - 1 self.tokenizer = tokenizer - self.sampling_params = sampling_params or {} - self.val_sampling_params = val_sampling_params or self.sampling_params - self.validate = False - self.bypass_render_with_parser = bypass_render_with_parser - self.accumulate_reasoning = accumulate_reasoning - self.reasoning_effort = reasoning_effort + self.train_sampling_params = dict(sampling_params.get("train", {})) if sampling_params else {} + self.val_sampling_params = dict(sampling_params.get("val", {})) if sampling_params else {} # Initialize Tinker service client self.service_client = service_client - if bypass_render_with_parser: - self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=disable_thinking) - self.renderer = None - if hasattr(self.chat_parser, "stop_sequences") and self.chat_parser.stop_sequences: - self.stop_sequences = self.chat_parser.stop_sequences - elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token: - self.stop_sequences = [tokenizer.eos_token] - else: - raise ValueError("No stop sequences found for tokenizer or chat parser") - else: - # Use explicit renderer_name if provided, otherwise auto-detect - renderer_name = renderer_name or model_info.get_recommended_renderer_name(self.model_name) - # Pass image_processor for VLM support with Tinker renderer - self.renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) - self.chat_parser = None - self.stop_sequences = self.renderer.get_stop_sequences() - - # Sampling client can be set later via set_sampling_client() - self.sampling_client = sampling_client + # Initialize the renderer and wrap with TinkerChatTemplateParser + renderer_name = renderer_name or model_info.get_recommended_renderer_name(self.model_name) + renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) + + if "strip_thinking_from_history" in kwargs and isinstance(kwargs["strip_thinking_from_history"], bool) and hasattr(renderer, "strip_thinking_from_history"): + renderer.strip_thinking_from_history = kwargs["strip_thinking_from_history"] + + self.chat_parser: TinkerChatTemplateParser = TinkerChatTemplateParser(renderer) + self.stop_sequences = self.chat_parser.stop_sequences - def set_sampling_client(self, sampling_client): + # Sampling client will be set via set_sampling_client() + self.sampling_client: tinker.SamplingClient | None = None + + def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None: """ Set the sampling client for inference. @@ -99,34 +125,6 @@ def set_sampling_client(self, sampling_client): """ self.sampling_client = sampling_client - def _convert_images_to_content_list(self, messages: list[dict]) -> list[dict]: - """ - Convert messages from standard format to Tinker renderer format. - - Standard format: {"role": "user", "content": "text", "images": [PIL.Image]} - Tinker format: {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": "..."}]} - - Args: - messages: List of messages in standard format - - Returns: - List of messages in Tinker renderer format - """ - converted = [] - for msg in messages: - if "images" in msg and msg["images"]: - # Convert to content list format - content_list = [] - for img in msg["images"]: - content_list.append({"type": "image", "image": img}) - content_list.append({"type": "text", "text": msg.get("content", "")}) - converted.append({**msg, "content": content_list}) - # Remove the images key since it's now in content - del converted[-1]["images"] - else: - converted.append(msg) - return converted - def _prepare_max_tokens(self, requested_max_tokens: int, prompt_length: int) -> int: """ Prepare max_tokens parameter, adjusting for max_model_length if needed. @@ -149,157 +147,80 @@ def _prepare_max_tokens(self, requested_max_tokens: int, prompt_length: int) -> return max_tokens - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - """ - Generate model response for a given set of messages. - - Args: - messages: List of message dictionaries (OpenAI format) - **kwargs: Additional parameters including: - - application_id: Session/application ID for tracing - - validate: Whether this is validation (for greedy decoding) - - enforce_max_prompt_length: Whether to enforce max prompt length - - tools: List of tools (used when bypass_render_with_parser=True) - - accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) + @property + def supports_token_in_token_out(self) -> bool: + """Tinker sampling client does support returning prompt_ids, so this is true.""" + return True - Returns: - ModelOutput with generated text and metadata + @override + async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TinkerTokenOutput: + """ + Generate a sampled sequence from a given token input. """ + token_input = cast(TinkerTokenInput, token_input) if self.sampling_client is None: raise RuntimeError("Sampling client not set. Call set_sampling_client() first.") - # Extract kwargs - kwargs.pop("application_id", None) - validate = kwargs.pop("validate", False) or self.validate - enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) - sampling_params = self.val_sampling_params if validate else self.sampling_params + input_length = _flat_token_input_length(token_input) - # Extract parser-specific kwargs - tools = kwargs.pop("tools", []) - accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) - reasoning_effort = kwargs.pop("reasoning_effort", self.reasoning_effort) - - if self.bypass_render_with_parser: - # Use ChatTemplateParser - prompt = self.chat_parser.parse( - messages, - add_generation_prompt=True, - is_first_msg=True, - tools=tools, - reasoning_effort=reasoning_effort, - accumulate_reasoning=accumulate_reasoning, - ) - prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) - prompt_length = len(prompt_ids) - - # Check prompt length - if enforce_max_prompt_length and (prompt_length > self.max_prompt_length or prompt_length > self.max_model_length): - raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - - # Dynamically adjust max_tokens based on prompt length - default_max_tokens = sampling_params.get("max_tokens", self.max_response_length) - requested_max_tokens = kwargs.get("max_tokens", kwargs.get("max_new_tokens", default_max_tokens)) - max_tokens = self._prepare_max_tokens(requested_max_tokens, prompt_length) - - # Prepare sampling params (override defaults with kwargs) - sampling_params = tinker.types.SamplingParams( - max_tokens=max_tokens, - stop=self.stop_sequences, - temperature=kwargs.get("temperature", sampling_params.get("temperature", 1.0)), - top_p=kwargs.get("top_p", sampling_params.get("top_p", 1.0)), - ) - - # Convert prompt to Tinker prompt format - tinker_prompt = ModelInput.from_ints(prompt_ids) - - # Call Tinker sampling API - sample_response = await self.sampling_client.sample_async( - prompt=tinker_prompt, - num_samples=1, - sampling_params=sampling_params, - ) - - # Extract response tokens and logprobs - response_tokens = sample_response.sequences[0].tokens - logprobs = sample_response.sequences[0].logprobs - - # Parse response using parser - parsed_output = self.chat_parser.parse_completion(response_tokens) - - content = parsed_output.get("content", "") - reasoning = parsed_output.get("reasoning", "") - tool_calls = parsed_output.get("tool_calls", []) - - # Decode full text - completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) - else: - # Use Tinker renderer (original behavior) - # Convert standard image format to Tinker renderer format - converted_messages = self._convert_images_to_content_list(messages) - # Build prompt using renderer (converts messages to Tinker prompt) - tinker_prompt = self.renderer.build_generation_prompt(converted_messages) - - # For VLM prompts with ImageChunks, to_ints() may not be supported - try: - prompt_ids = tinker_prompt.to_ints() - prompt_length = len(prompt_ids) - except ValueError: - # Prompt contains ImageChunks - skip length enforcement for VLM - prompt_ids = [] - prompt_length = 0 - - # Check prompt length (only for text-only prompts) - if prompt_length > 0 and enforce_max_prompt_length and (prompt_length > self.max_prompt_length or prompt_length > self.max_model_length): - raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - - # Dynamically adjust max_tokens based on prompt length - default_max_tokens = sampling_params.get("max_tokens", self.max_response_length) - requested_max_tokens = kwargs.get("max_tokens", kwargs.get("max_new_tokens", default_max_tokens)) - max_tokens = self._prepare_max_tokens(requested_max_tokens, prompt_length) if prompt_length > 0 else requested_max_tokens - - # Prepare sampling params (override defaults with kwargs) - sampling_params = tinker.types.SamplingParams( - max_tokens=max_tokens, - stop=self.stop_sequences, - temperature=kwargs.get("temperature", sampling_params.get("temperature", 1.0)), - top_p=kwargs.get("top_p", sampling_params.get("top_p", 1.0)), - ) - - # Call Tinker sampling API - sample_response = await self.sampling_client.sample_async( - prompt=tinker_prompt, - num_samples=1, - sampling_params=sampling_params, - ) - - # Extract response tokens and logprobs - response_tokens = sample_response.sequences[0].tokens - logprobs = sample_response.sequences[0].logprobs - - # Parse response using renderer - parsed_msg, _ = self.renderer.parse_response(response_tokens) - raw_content = parsed_msg["content"] - tool_calls = [] - for tc in parsed_msg.get("tool_calls", []): - try: - tool_calls.append(ToolCall(name=tc.function.name, arguments=json.loads(tc.function.arguments))) - except (json.JSONDecodeError, AttributeError): - continue - - if isinstance(raw_content, list): - reasoning = next((p["thinking"] for p in raw_content if p["type"] == "thinking"), "") - content = next((p["text"] for p in raw_content if p["type"] == "text"), "") - else: - content = raw_content - reasoning = "" + enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) + if enforce_max_prompt_length and input_length > min(self.max_prompt_length, self.max_model_length): + raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + + # prepare sampling params + sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() + + requested_max_tokens = kwargs.pop("max_tokens", kwargs.pop("max_new_tokens", self.max_response_length)) + requested_max_tokens = sampling_params.pop("max_tokens", requested_max_tokens) + max_tokens = self._prepare_max_tokens(requested_max_tokens, input_length) + + if "temperature" in kwargs: + sampling_params["temperature"] = kwargs["temperature"] + if "top_p" in kwargs: + sampling_params["top_p"] = kwargs["top_p"] + if "top_k" in kwargs: + sampling_params["top_k"] = kwargs["top_k"] + + tinker_sampling_params = tinker.types.SamplingParams( + max_tokens=max_tokens, + stop=self.stop_sequences, # type: ignore + **sampling_params, + ) + # call sampling client + model_input = _flat_token_input_to_model_input(token_input) + sample_response: tinker.SampleResponse = await self.sampling_client.sample_async( + prompt=model_input, + num_samples=1, + sampling_params=tinker_sampling_params, + ) - # Decode full text - completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) + # return sampled sequence from sample response + return sample_response.sequences[0] - # Determine finish reason - finish_reason = "stop" - if len(response_tokens) >= sampling_params.max_tokens: - finish_reason = "length" + @override + def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: + """ + Assemble model output from a sampled sequence. + """ + sampled_sequence = cast(TinkerTokenOutput, token_output) + response_tokens, logprobs = sampled_sequence.tokens, sampled_sequence.logprobs + + # Parse response using parser (handles content, reasoning, tool_calls) + parsed_output = self.chat_parser.parse_completion(response_tokens) + content = parsed_output.get("content", "") + reasoning = parsed_output.get("reasoning", "") + tool_calls = parsed_output.get("tool_calls", []) + + # decode full text + completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) # type: ignore + finish_reason = sampled_sequence.stop_reason + # special handling for prompt ids, we will break any EncodedTextChunk into ints + prompt_ids = [] + for elem in token_input: + if isinstance(elem, tinker.EncodedTextChunk): + prompt_ids.extend(elem.tokens) + else: + prompt_ids.append(elem) return ModelOutput( text=completion_text, @@ -309,11 +230,39 @@ async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutpu prompt_ids=prompt_ids, completion_ids=response_tokens, logprobs=logprobs, - prompt_length=prompt_length, + prompt_length=_flat_token_input_length(token_input), completion_length=len(response_tokens), finish_reason=finish_reason, ) + @override + async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + """ + Generate model response for a given set of messages. + + Args: + messages: List of message dictionaries (OpenAI format) + **kwargs: Additional parameters including: + - application_id: Session/application ID for tracing + - enforce_max_prompt_length: Whether to enforce max prompt length + - tools: List of tools for tool-augmented generation + + Returns: + ModelOutput with generated text and metadata + """ + # Extract unused kwargs + kwargs.pop("application_id", None) + + # Extract tools + tools = kwargs.pop("tools", []) + + # Build prompt using TinkerChatTemplateParser (handles tools, images, etc.) + tinker_prompt = self.chat_parser.build_prompt(messages, tools=tools) + token_input: TinkerTokenInput = tinker_prompt.chunks + + sampled_sequence = await self.get_token_output_from_token_input(token_input=token_input, **kwargs) + return self.assemble_model_output(token_input=token_input, token_output=sampled_sequence) + async def compute_logprobs(self, ids: list[int]) -> list[float]: ids = ids[: self.max_model_length] return await self.sampling_client.compute_logprobs_async(ModelInput.from_ints(ids)) diff --git a/rllm/experimental/rollout/types.py b/rllm/engine/rollout/types.py similarity index 92% rename from rllm/experimental/rollout/types.py rename to rllm/engine/rollout/types.py index 22b30195b..d52466d2d 100644 --- a/rllm/experimental/rollout/types.py +++ b/rllm/engine/rollout/types.py @@ -17,7 +17,8 @@ Processor: TypeAlias = Any ImageProcessor: TypeAlias = Any -# Tinker types. See https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py +# Tinker types. +# See https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py # for the rationale behind "FlatObElem" and "FlatOb" types. try: from tinker.types import ModelInputChunk, SampledSequence diff --git a/rllm/engine/rollout/verl_engine.py b/rllm/engine/rollout/verl_engine.py index 5a19e07c0..98ddc5b13 100644 --- a/rllm/engine/rollout/verl_engine.py +++ b/rllm/engine/rollout/verl_engine.py @@ -1,16 +1,19 @@ import asyncio import uuid +from typing import cast +from omegaconf import DictConfig +from typing_extensions import override from verl.experimental.agent_loop.agent_loop import AgentLoopManager, AsyncLLMServerManager -from verl.workers.rollout.replica import TokenOutput from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine +from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput, VerlTokenOutput from rllm.parser import ChatTemplateParser from rllm.workflows import TerminationEvent, TerminationReason class VerlEngine(RolloutEngine): - def __init__(self, config, rollout_manager, tokenizer, processor=None, **kwargs): + def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokenizer: Tokenizer, processor=None, **kwargs): self.config = config if config.actor_rollout_ref.rollout.name not in ["vllm", "sglang"]: @@ -43,21 +46,35 @@ def __init__(self, config, rollout_manager, tokenizer, processor=None, **kwargs) print(f"train_sampling_params: {self.train_sampling_params}") print(f"val_sampling_params: {self.val_sampling_params}") - self.validate = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks_verl + @property + def supports_token_in_token_out(self) -> bool: + return True - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + @override + async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> VerlTokenOutput: + token_input = cast(list[int], token_input) + + input_length = len(token_input) application_id = kwargs.pop("application_id", str(uuid.uuid4())) - validate = self.validate or kwargs.pop("validate", False) enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) - # these go to the parser - tools = kwargs.pop("tools", []) - accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) + if enforce_max_prompt_length and input_length > self.max_prompt_length: + raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - sampling_params = self.val_sampling_params.copy() if self.validate or validate else self.train_sampling_params.copy() + sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() sampling_params.update(kwargs) - max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length)) + # starting from verl 0.7.0, we can pass in per-turn max_tokens into the sampling_params + sampling_params["max_tokens"] = max_tokens + + token_output = await self.server_manager.generate(request_id=application_id, prompt_ids=token_input, sampling_params=sampling_params) + return token_output + + @override + async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + # these go to the parser + tools = kwargs.pop("tools", []) + accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning) request_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) # list[int] @@ -73,19 +90,26 @@ async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutpu multi_modal_inputs = None prompt_ids = request_prompt_ids - prompt_length = len(prompt_ids) - if enforce_max_prompt_length and prompt_length > self.max_prompt_length: - raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + token_output: TokenOutput = await self.get_token_output_from_token_input(token_input=request_prompt_ids, **kwargs) + extra_kwargs = dict(prompt_ids=prompt_ids, multi_modal_inputs=multi_modal_inputs) + return self.assemble_model_output(token_input=request_prompt_ids, token_output=token_output, **extra_kwargs) - token_output: TokenOutput = await self.server_manager.generate(request_id=application_id, prompt_ids=request_prompt_ids, image_data=image_data, sampling_params=sampling_params) # type: ignore - completion_ids: list[int] = token_output.token_ids - logprobs: list[float] = token_output.log_probs + @override + def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput, **kwargs) -> ModelOutput: + prompt_ids = kwargs.pop("prompt_ids", None) + multi_modal_inputs = kwargs.pop("multi_modal_inputs", None) + prompt_length = len(prompt_ids) if prompt_ids is not None else 0 - finish_reason = "stop" - if len(completion_ids) >= max_tokens: - finish_reason = "length" - completion_ids = completion_ids[:max_tokens] - logprobs = logprobs[:max_tokens] + token_output = cast(VerlTokenOutput, token_output) + completion_ids = token_output.token_ids + logprobs = token_output.log_probs + + # convert the stop reason from verl back to the standard finish reason TODO(listar2000): check backward-compatibility + reason_mapping = {"aborted": "abort", "completed": "stop"} + if token_output.stop_reason is not None: + finish_reason = reason_mapping.get(token_output.stop_reason, token_output.stop_reason) + else: + finish_reason = "stop" completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True) # TODO: implement parse_completion for the standard parser diff --git a/rllm/experimental/config/rllm/backend/tinker.yaml b/rllm/experimental/config/rllm/backend/tinker.yaml index 5fba90ed9..184b0c8a4 100644 --- a/rllm/experimental/config/rllm/backend/tinker.yaml +++ b/rllm/experimental/config/rllm/backend/tinker.yaml @@ -59,10 +59,7 @@ agent: # Tinker Engine Configuration rollout_engine: - reasoning_effort: "medium" - accumulate_reasoning: false - disable_thinking: false - bypass_render_with_parser: false + strip_thinking_from_history: true renderer_name: null # Data Configuration diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index 8f2e7e80c..96d2e04e9 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -10,7 +10,7 @@ from tqdm import tqdm from rllm.agents.agent import Episode -from rllm.experimental.rollout import RolloutEngine +from rllm.engine.rollout import RolloutEngine from rllm.utils import colorful_print from rllm.workflows.workflow import TerminationReason, Workflow @@ -232,7 +232,7 @@ async def execute_tasks_verl(self, batch: DataProto, is_validation: bool = False Returns: list[Episode]: List of completed episodes. """ - from rllm.experimental.rollout import VerlEngine + from rllm.engine.rollout import VerlEngine assert isinstance(self.rollout_engine, VerlEngine), "Rollout engine must be a VerlEngine to invoke execute_tasks_verl" await self.rollout_engine.wake_up() diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index 0c8491dbe..17672bf41 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -16,8 +16,8 @@ from rllm.agents.agent import Episode from rllm.data import Dataset +from rllm.engine.rollout import RolloutEngine from rllm.experimental.common.advantage import AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups -from rllm.experimental.rollout import RolloutEngine if TYPE_CHECKING: from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine diff --git a/rllm/experimental/rollout/__init__.py b/rllm/experimental/rollout/__init__.py index 6e5c4d681..7fe19012a 100644 --- a/rllm/experimental/rollout/__init__.py +++ b/rllm/experimental/rollout/__init__.py @@ -1,19 +1,20 @@ -from typing import TYPE_CHECKING - -from .rollout_engine import ModelOutput, RolloutEngine -from .types import TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput, VerlTokenInput, VerlTokenOutput - -if TYPE_CHECKING: - from .tinker_engine import TinkerEngine - from .verl_engine import VerlEngine +# Backward compatibility: re-export from canonical location +from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine # noqa: F401 +from rllm.engine.rollout.types import ( # noqa: F401 + TinkerTokenInput, + TinkerTokenOutput, + TokenInput, + Tokenizer, + TokenOutput, + VerlTokenInput, + VerlTokenOutput, +) __all__ = [ "ModelOutput", - # Rollout engines "RolloutEngine", "TinkerEngine", "VerlEngine", - # Token input/output types "TokenInput", "TokenOutput", "TinkerTokenInput", @@ -25,12 +26,16 @@ def __getattr__(name): + # Lazy imports for engines with heavy dependencies if name == "TinkerEngine": - from .tinker_engine import TinkerEngine as _TinkerEngine + from rllm.engine.rollout.tinker_engine import TinkerEngine as _TinkerEngine return _TinkerEngine if name == "VerlEngine": - from .verl_engine import VerlEngine as _VerlEngine + try: + from rllm.engine.rollout.verl_engine import VerlEngine as _VerlEngine - return _VerlEngine + return _VerlEngine + except Exception: + raise AttributeError(name) from None raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py deleted file mode 100644 index ceb9c603e..000000000 --- a/rllm/experimental/rollout/rollout_engine.py +++ /dev/null @@ -1,87 +0,0 @@ -from dataclasses import dataclass - -from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput -from rllm.parser import ChatTemplateParser -from rllm.tools.tool_base import ToolCall - - -@dataclass -class ModelOutput: - text: str | None = None - content: str | None = None - reasoning: str | None = None - tool_calls: list[ToolCall] | None = None - prompt_ids: TokenInput | None = None - completion_ids: list[int] | None = None - multi_modal_inputs: dict[str, list] | None = None - logprobs: list[float] | None = None # completion logprobs - prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids - prompt_length: int = 0 - completion_length: int = 0 - finish_reason: str | None = None - - def to_dict(self): - return { - "text": self.text, - "content": self.content, - "reasoning": self.reasoning, - "tool_calls": [tool_call.to_dict() for tool_call in self.tool_calls] if self.tool_calls else [], - "prompt_ids": self.prompt_ids, - "completion_ids": self.completion_ids, - "multi_modal_inputs": self.multi_modal_inputs, - "logprobs": self.logprobs, - "prompt_logprobs": self.prompt_logprobs, - "prompt_length": self.prompt_length, - "completion_length": self.completion_length, - "finish_reason": self.finish_reason, - } - - @classmethod - def from_dict(cls, data: dict): - return cls( - text=data.get("text"), - content=data.get("content"), - reasoning=data.get("reasoning"), - tool_calls=[ToolCall(**tool_call) for tool_call in data.get("tool_calls", [])] if data.get("tool_calls") else None, - prompt_ids=data.get("prompt_ids"), - completion_ids=data.get("completion_ids"), - multi_modal_inputs=data.get("multi_modal_inputs"), - logprobs=data.get("logprobs"), - prompt_logprobs=data.get("prompt_logprobs"), - prompt_length=data.get("prompt_length", 0), - completion_length=data.get("completion_length", 0), - finish_reason=data.get("finish_reason"), - ) - - -class RolloutEngine: - chat_parser: ChatTemplateParser | None = None - tokenizer: Tokenizer | None = None - is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks - - def __init__(self, *args, **kwargs): - pass - - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - raise NotImplementedError("get_model_response is not implemented") - - def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: - """ - Assemble model output from a token output. - """ - raise NotImplementedError("assemble_model_output is not implemented") - - async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TokenOutput: - """Obtain the token output from the given token input.""" - raise NotImplementedError("get_token_output_from_token_input is not implemented") - - async def wake_up(self): - pass - - async def sleep(self): - pass - - @property - def supports_token_in_token_out(self) -> bool: - """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" - return False diff --git a/rllm/experimental/rollout/tinker_engine.py b/rllm/experimental/rollout/tinker_engine.py deleted file mode 100644 index 27bf4ea77..000000000 --- a/rllm/experimental/rollout/tinker_engine.py +++ /dev/null @@ -1,350 +0,0 @@ -from typing import Any, cast - -import tinker -from tinker.types import ModelInput -from tinker_cookbook import model_info, renderers -from tinker_cookbook.renderers import Message -from typing_extensions import override # need to use typing_extensions for python < 3.12 - -from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.experimental.rollout.types import ImageProcessor, Processor, TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput -from rllm.parser import ChatTemplateParser -from rllm.workflows import TerminationEvent, TerminationReason - -""" -Utility functions for Tinker engine. Partly borrowed from -https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py -""" - - -def _flat_token_input_to_model_input(token_input: TinkerTokenInput) -> ModelInput: - """Convert a flat token input to a ModelInput.""" - if not token_input: # empty list - return ModelInput(chunks=[]) - - out: list[tinker.ModelInputChunk] = [] - current_text_chunk: list[int] = [] - - def flush_text_chunk(): - if current_text_chunk: - out.append(tinker.EncodedTextChunk(tokens=current_text_chunk)) - current_text_chunk.clear() - - for elem in token_input: - if isinstance(elem, int): - current_text_chunk.append(elem) - else: - flush_text_chunk() - out.append(elem) - - flush_text_chunk() # final clear up - return tinker.ModelInput(chunks=out) - - -def _flat_token_input_length(token_input: TokenInput) -> int: - """Get the length of a flat token input. This nicely handles both text and image inputs""" - length = 0 - for elem in token_input: - if isinstance(elem, int): - length += 1 - else: - length += elem.length - return length - - -def _parse_tinker_message(message: Message) -> tuple[str, str, list[Any]]: - tinker_content = message["content"] - if isinstance(tinker_content, list): - text_parts, think_parts = [], [] - for part in tinker_content: - if part["type"] == "text": - text_parts.append(part) - elif part["type"] == "thinking": - think_parts.append(part) - content = "\n".join([text["text"] for text in text_parts]) - reasoning = "\n".join([think["thinking"] for think in think_parts]) - else: # no reasoning parsed - content = tinker_content - reasoning = "" - # TODO(listar2000): the Tinker tool_calls is not fully compatible with the rLLM one - tool_calls = message.get("tool_calls", []) - return content, reasoning, tool_calls - - -class TinkerEngine(RolloutEngine): - """ - RolloutEngine implementation using Tinker for model inference. - """ - - def __init__( - self, - base_url: str, - model_name: str, - tokenizer: Tokenizer, - service_client: tinker.ServiceClient, - max_prompt_length: int = 4096, - max_response_length: int = 4096, - max_model_length: int = 32768, - sampling_params: dict | None = None, - bypass_render_with_parser: bool = True, # default to True now - processor: Processor | None = None, - image_processor: ImageProcessor | None = None, - disable_thinking: bool = False, - accumulate_reasoning: bool = False, - reasoning_effort: str = "medium", - renderer_name: str | None = None, - **kwargs, - ): - """ - Initialize TinkerEngine. - - Args: - base_url: Tinker service base URL - model_name: Name of the model to use - tokenizer: Tokenizer for encoding/decoding - service_client: Tinker ServiceClient instance - max_prompt_length: Maximum prompt length in tokens - max_response_length: Maximum response length in tokens - max_model_length: Maximum total length (prompt + response) in tokens - sampling_params: Default sampling parameters (temperature, top_p, etc.) - bypass_render_with_parser: If True, use ChatTemplateParser instead of Tinker's renderer - processor: Optional processor for multimodal models (used when bypass_render_with_parser=True) - image_processor: Optional image processor for vision-language models (used with renderer) - disable_thinking: Whether to disable thinking in generation prompt (used when bypass_render_with_parser=True) - accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) - """ - self.base_url = base_url - self.model_name = model_name - self.max_prompt_length = max_prompt_length - self.max_response_length = max_response_length - self.max_model_length = max_model_length - 1 - self.tokenizer = tokenizer - self.bypass_render_with_parser = bypass_render_with_parser - self.accumulate_reasoning = accumulate_reasoning - self.reasoning_effort = reasoning_effort - - self.train_sampling_params = dict(sampling_params.get("train", {})) if sampling_params else {} - self.val_sampling_params = dict(sampling_params.get("val", {})) if sampling_params else {} - # Initialize Tinker service client - self.service_client = service_client - - # Initialize the renderer - renderer_name = renderer_name or model_info.get_recommended_renderer_name(self.model_name) - # Pass image_processor for VLM support with Tinker renderer - self.renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) - - if bypass_render_with_parser: - self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=disable_thinking) - if hasattr(self.chat_parser, "stop_sequences") and self.chat_parser.stop_sequences: - self.stop_sequences = self.chat_parser.stop_sequences - elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id: - self.stop_sequences = [tokenizer.eos_token_id] - else: - raise ValueError("No stop sequences found for tokenizer or chat parser") - else: - self.chat_parser = None - self.stop_sequences = self.renderer.get_stop_sequences() - - # Sampling client will be set via set_sampling_client() - self.sampling_client = None - - def set_sampling_client(self, sampling_client): - """ - Set the sampling client for inference. - - Args: - sampling_client: Tinker SamplingClient instance - """ - self.sampling_client = sampling_client - - def _convert_images_to_content_list(self, messages: list[dict]) -> list[dict]: - """ - Convert messages from standard format to Tinker renderer format. - - Standard format: {"role": "user", "content": "text", "images": [PIL.Image]} - Tinker format: {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": "..."}]} - - Args: - messages: List of messages in standard format - - Returns: - List of messages in Tinker renderer format - """ - converted = [] - for msg in messages: - if "images" in msg and msg["images"]: - # Convert to content list format - content_list = [] - for img in msg["images"]: - content_list.append({"type": "image", "image": img}) - content_list.append({"type": "text", "text": msg.get("content", "")}) - converted.append({**msg, "content": content_list}) - # Remove the images key since it's now in content - del converted[-1]["images"] - else: - converted.append(msg) - return converted - - def _prepare_max_tokens(self, requested_max_tokens: int, prompt_length: int) -> int: - """ - Prepare max_tokens parameter, adjusting for max_model_length if needed. - - Args: - requested_max_tokens: The requested max_tokens value - prompt_length: The length of the prompt in tokens - - Returns: - Adjusted max_tokens value - """ - max_tokens = requested_max_tokens - - # Adjust for prompt length if max_model_length is set - if self.max_model_length: - remaining = self.max_model_length - prompt_length - if remaining <= max_tokens: - max_tokens = remaining - print(f"Warning: Decreasing max_tokens to {max_tokens} to stay within max_model_length") - - return max_tokens - - @property - def supports_token_in_token_out(self) -> bool: - """Tinker sampling client does support returning prompt_ids, so this is true.""" - return True - - @override - async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TinkerTokenOutput: - """ - Generate a sampled sequence from a given token input. - """ - token_input = cast(TinkerTokenInput, token_input) - if self.sampling_client is None: - raise RuntimeError("Sampling client not set. Call set_sampling_client() first.") - - input_length = _flat_token_input_length(token_input) - - enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) - if enforce_max_prompt_length and input_length > min(self.max_prompt_length, self.max_model_length): - raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - - # prepare sampling params - sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() - - requested_max_tokens = kwargs.pop("max_tokens", kwargs.pop("max_new_tokens", self.max_response_length)) - requested_max_tokens = sampling_params.pop("max_tokens", requested_max_tokens) - max_tokens = self._prepare_max_tokens(requested_max_tokens, input_length) - - if "temperature" in kwargs: - sampling_params["temperature"] = kwargs["temperature"] - if "top_p" in kwargs: - sampling_params["top_p"] = kwargs["top_p"] - if "top_k" in kwargs: - sampling_params["top_k"] = kwargs["top_k"] - - tinker_sampling_params = tinker.types.SamplingParams( - max_tokens=max_tokens, - stop=self.stop_sequences, # type: ignore - **sampling_params, - ) - # call sampling client - model_input = _flat_token_input_to_model_input(token_input) - sample_response: tinker.SampleResponse = await self.sampling_client.sample_async( - prompt=model_input, - num_samples=1, - sampling_params=tinker_sampling_params, - ) - - # return sampled sequence from sample response - return sample_response.sequences[0] - - @override - def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: - """ - Assemble model output from a sampled sequence. - """ - sampled_sequence = cast(TinkerTokenOutput, token_output) - response_tokens, logprobs = sampled_sequence.tokens, sampled_sequence.logprobs - - if self.bypass_render_with_parser: - assert self.chat_parser is not None, "chat_parser must be set when bypass_render_with_parser=True" - parsed_output = self.chat_parser.parse_completion(response_tokens) - content = parsed_output.get("content", "") - reasoning = parsed_output.get("reasoning", "") - tool_calls = parsed_output.get("tool_calls", []) - else: - assert isinstance(self.renderer, renderers.Renderer), "self.renderer must be a valid Tinker Renderer" - response_message, _ = self.renderer.parse_response(response_tokens) - content, reasoning, tool_calls = _parse_tinker_message(response_message) - - # decode full text - completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) # type: ignore - finish_reason = sampled_sequence.stop_reason - # special handling for prompt ids, we will break any EncodedTextChunk into ints - prompt_ids = [] - for elem in token_input: - if isinstance(elem, tinker.EncodedTextChunk): - prompt_ids.extend(elem.tokens) - else: - prompt_ids.append(elem) - - return ModelOutput( - text=completion_text, - content=content, - reasoning=reasoning, - tool_calls=tool_calls, - prompt_ids=prompt_ids, - completion_ids=response_tokens, - logprobs=logprobs, - prompt_length=_flat_token_input_length(token_input), - completion_length=len(response_tokens), - finish_reason=finish_reason, - ) - - @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - """ - Generate model response for a given set of messages. - - Args: - messages: List of message dictionaries (OpenAI format) - **kwargs: Additional parameters including: - - application_id: Session/application ID for tracing - - enforce_max_prompt_length: Whether to enforce max prompt length - - tools: List of tools (used when bypass_render_with_parser=True) - - accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) - - Returns: - ModelOutput with generated text and metadata - """ - # Extract unused kwargs - kwargs.pop("application_id", None) - - # Extract parser-specific kwargs - tools = kwargs.pop("tools", []) - accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) - reasoning_effort = kwargs.pop("reasoning_effort", self.reasoning_effort) - - if self.bypass_render_with_parser: - # Use ChatTemplateParser - prompt = self.chat_parser.parse( # type: ignore - messages, - add_generation_prompt=True, - is_first_msg=True, - tools=tools, - reasoning_effort=reasoning_effort, - accumulate_reasoning=accumulate_reasoning, - ) - token_input = self.tokenizer.encode(prompt, add_special_tokens=False) # type: ignore - else: - # Use Tinker renderer - # Convert standard image format to Tinker renderer format - converted_messages = self._convert_images_to_content_list(messages) - # Build prompt using renderer - token_input: TinkerTokenInput = self.renderer.build_generation_prompt(converted_messages).chunks # type: ignore - - sampled_sequence = await self.get_token_output_from_token_input(token_input=token_input, **kwargs) - return self.assemble_model_output(token_input=token_input, token_output=sampled_sequence) - - async def compute_logprobs(self, ids: list[int]) -> list[float]: - ids = ids[: self.max_model_length] - return await self.sampling_client.compute_logprobs_async(ModelInput.from_ints(ids)) diff --git a/rllm/experimental/rollout/verl_engine.py b/rllm/experimental/rollout/verl_engine.py deleted file mode 100644 index 48d42aa7c..000000000 --- a/rllm/experimental/rollout/verl_engine.py +++ /dev/null @@ -1,138 +0,0 @@ -import asyncio -import uuid -from typing import cast - -from omegaconf import DictConfig -from typing_extensions import override -from verl.experimental.agent_loop.agent_loop import AgentLoopManager, AsyncLLMServerManager - -from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput, VerlTokenOutput -from rllm.parser import ChatTemplateParser -from rllm.workflows import TerminationEvent, TerminationReason - - -class VerlEngine(RolloutEngine): - def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokenizer: Tokenizer, processor=None, **kwargs): - self.config = config - - if config.actor_rollout_ref.rollout.name not in ["vllm", "sglang"]: - raise ValueError(f"VerlEngine only supports vllm or sglang rollout, but got {config.actor_rollout_ref.rollout.name}") - - self.rollout_manager: AgentLoopManager = rollout_manager - self.server_manager = AsyncLLMServerManager(config, server_handles=rollout_manager.server_handles) - self.tokenizer = tokenizer - self.processor = processor - self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=config.get("rllm", {}).get("disable_thinking", False)) - - self.max_prompt_length = config.data.max_prompt_length - self.max_response_length = config.data.max_response_length - self.accumulate_reasoning = config.get("rllm", {}).get("accumulate_reasoning", False) - - self.train_sampling_params = dict( - temperature=0.0 if config.actor_rollout_ref.rollout.do_sample is False else config.actor_rollout_ref.rollout.temperature, - top_k=config.actor_rollout_ref.rollout.top_k, - top_p=config.actor_rollout_ref.rollout.top_p, - logprobs=1, - ) - - self.val_sampling_params = dict( - temperature=0.0 if config.actor_rollout_ref.rollout.val_kwargs.do_sample is False else config.actor_rollout_ref.rollout.val_kwargs.temperature, - top_k=config.actor_rollout_ref.rollout.val_kwargs.top_k, - top_p=config.actor_rollout_ref.rollout.val_kwargs.top_p, - logprobs=1, - ) - - print(f"train_sampling_params: {self.train_sampling_params}") - print(f"val_sampling_params: {self.val_sampling_params}") - - @property - def supports_token_in_token_out(self) -> bool: - return True - - @override - async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> VerlTokenOutput: - token_input = cast(list[int], token_input) - - input_length = len(token_input) - application_id = kwargs.pop("application_id", str(uuid.uuid4())) - enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) - - if enforce_max_prompt_length and input_length > self.max_prompt_length: - raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - - sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() - sampling_params.update(kwargs) - max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length)) - # starting from verl 0.7.0, we can pass in per-turn max_tokens into the sampling_params - sampling_params["max_tokens"] = max_tokens - - token_output = await self.server_manager.generate(request_id=application_id, prompt_ids=token_input, sampling_params=sampling_params) - return token_output - - @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - # these go to the parser - tools = kwargs.pop("tools", []) - accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) - - prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning) - request_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) # list[int] - - if any(msg.get("images", None) is not None and msg["role"] == "user" for msg in messages) and self.processor is not None: - image_data = self.chat_parser.process_image_data(messages) # list[PIL.Image.Image] - model_inputs = self.processor(text=[prompt], images=image_data) - prompt_ids = model_inputs.pop("input_ids")[0] # list[int] - model_inputs.pop("attention_mask") - multi_modal_inputs = dict(model_inputs) - else: - image_data = None - multi_modal_inputs = None - prompt_ids = request_prompt_ids - - token_output: TokenOutput = await self.get_token_output_from_token_input(token_input=request_prompt_ids, **kwargs) - extra_kwargs = dict(prompt_ids=prompt_ids, multi_modal_inputs=multi_modal_inputs) - return self.assemble_model_output(token_input=request_prompt_ids, token_output=token_output, **extra_kwargs) - - @override - def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput, **kwargs) -> ModelOutput: - prompt_ids = kwargs.pop("prompt_ids", None) - multi_modal_inputs = kwargs.pop("multi_modal_inputs", None) - prompt_length = len(prompt_ids) if prompt_ids is not None else 0 - - token_output = cast(VerlTokenOutput, token_output) - completion_ids = token_output.token_ids - logprobs = token_output.log_probs - - # convert the stop reason from verl back to the standard finish reason TODO(listar2000): check backward-compatibility - reason_mapping = {"aborted": "abort", "completed": "stop"} - if token_output.stop_reason is not None: - finish_reason = reason_mapping.get(token_output.stop_reason, token_output.stop_reason) - else: - finish_reason = "stop" - - completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True) - # TODO: implement parse_completion for the standard parser - parsed_output = self.chat_parser.parse_completion(completion_ids) - - return ModelOutput( - text=completion_text, - content=parsed_output["content"], - reasoning=parsed_output["reasoning"], - tool_calls=parsed_output["tool_calls"], - prompt_ids=prompt_ids, - completion_ids=completion_ids, - multi_modal_inputs=multi_modal_inputs, - logprobs=logprobs, - prompt_length=prompt_length, - completion_length=len(completion_ids), - finish_reason=finish_reason, - ) - - async def wake_up(self): - """Wake up all rollout replica instances asynchronously.""" - await asyncio.gather(*[replica.wake_up() for replica in self.rollout_manager.rollout_replicas]) - - async def sleep(self): - """Sleep all rollout replica instances asynchronously.""" - await asyncio.gather(*[replica.sleep() for replica in self.rollout_manager.rollout_replicas]) diff --git a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py index 8f9488591..3cd314158 100644 --- a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py +++ b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py @@ -1,7 +1,7 @@ from rllm.agents.agent import Episode, Trajectory +from rllm.engine.rollout.completer import Completer +from rllm.engine.rollout.rollout_engine import RolloutEngine from rllm.experimental.opsd.workflow_utils import OPSDConfig, opsd_postprocess -from rllm.experimental.rollout.completer import Completer -from rllm.experimental.rollout.rollout_engine import RolloutEngine from rllm.rewards.reward_fn import math_reward_fn from rllm.workflows.workflow import Workflow diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 2d1521362..2496693b9 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -12,6 +12,7 @@ from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset +from rllm.engine.rollout import RolloutEngine from rllm.experimental.common.advantage import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, @@ -31,7 +32,6 @@ from rllm.experimental.common.visualization import visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine from rllm.experimental.protocol import BackendProtocol -from rllm.experimental.rollout import RolloutEngine from rllm.utils import EpisodeLogger, Tracking, extract_source_metadata from rllm.workflows.workflow import TerminationReason, Workflow diff --git a/rllm/experimental/verl/verl_backend.py b/rllm/experimental/verl/verl_backend.py index 5729a438c..64e4c4a5a 100644 --- a/rllm/experimental/verl/verl_backend.py +++ b/rllm/experimental/verl/verl_backend.py @@ -28,13 +28,13 @@ from rllm.agents.agent import Episode from rllm.data import Dataset +from rllm.engine.rollout import RolloutEngine, VerlEngine from rllm.experimental.common import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, simple_timer, ) from rllm.experimental.protocol import BackendProtocol -from rllm.experimental.rollout import RolloutEngine, VerlEngine from rllm.experimental.verl import compute_advantage_verl, transform_episodes_to_dataproto, update_dataproto_with_advantages if TYPE_CHECKING: @@ -409,10 +409,10 @@ async def on_validation_start(self, trainer_state: TrainerState) -> bool: return False else: trainer_state.is_training = False - self.rollout_engine.validate = True # type: ignore[attr-defined] + self.rollout_engine.is_validation = True return True async def on_validation_end(self, trainer_state: TrainerState) -> None: """Called at the end of validation.""" trainer_state.is_training = True - self.rollout_engine.validate = False # type: ignore[attr-defined] + self.rollout_engine.is_validation = False diff --git a/rllm/parser/__init__.py b/rllm/parser/__init__.py index 116726144..968f2b51e 100644 --- a/rllm/parser/__init__.py +++ b/rllm/parser/__init__.py @@ -6,6 +6,7 @@ "DeepseekQwenChatTemplateParser", "QwenChatTemplateParser", "LlamaChatTemplateParser", + "TinkerChatTemplateParser", "ToolParser", "R1ToolParser", "QwenToolParser", @@ -20,3 +21,11 @@ def get_tool_parser(parser_name: str) -> type[ToolParser]: assert parser_name in PARSER_REGISTRY, f"Tool parser {parser_name} not found in {PARSER_REGISTRY}" return PARSER_REGISTRY[parser_name] + + +def __getattr__(name): + if name == "TinkerChatTemplateParser": + from rllm.parser.tinker_parser import TinkerChatTemplateParser + + return TinkerChatTemplateParser + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/rllm/parser/tinker_parser.py b/rllm/parser/tinker_parser.py new file mode 100644 index 000000000..0bfe5a9dd --- /dev/null +++ b/rllm/parser/tinker_parser.py @@ -0,0 +1,400 @@ +import json +import logging + +import torch + +from rllm.parser.chat_template_parser import ChatTemplateParser +from rllm.tools.tool_base import Tool, ToolCall + +logger = logging.getLogger(__name__) + + +try: + import tinker + from tinker.types import ModelInput + from tinker_cookbook.renderers.base import RenderContext, Renderer, TrainOnWhat +except ImportError as e: + raise ImportError("tinker-cookbook and tinker are required for TinkerChatTemplateParser. Install them with: pip install tinker-cookbook tinker") from e + + +def _make_render_context(idx, is_last, prev_message=None, last_user_index=-1): + """Create a RenderContext, handling version differences in tinker-cookbook.""" + try: + return RenderContext( + idx=idx, + is_last=is_last, + prev_message=prev_message, + last_user_index=last_user_index, + ) + except TypeError: + # Older tinker-cookbook without last_user_index field + return RenderContext(idx=idx, is_last=is_last, prev_message=prev_message) + + +class TinkerChatTemplateParser(ChatTemplateParser): + """ChatTemplateParser that delegates to a tinker-cookbook Renderer. + + This allows users who have tinker-cookbook installed to use any tinker + renderer through rllm's ChatTemplateParser interface, avoiding the need + to write a manual parser for each model family. + + Example:: + + from tinker_cookbook import renderers, tokenizer_utils + from rllm.parser import TinkerChatTemplateParser + + tokenizer = tokenizer_utils.get_tokenizer("Qwen/Qwen3-8B") + renderer = renderers.get_renderer("qwen3", tokenizer) + parser = TinkerChatTemplateParser(renderer) + + prompt = parser.parse(messages, add_generation_prompt=True, is_first_msg=True) + """ + + def __init__(self, renderer: Renderer) -> None: + if not isinstance(renderer, Renderer): + raise TypeError(f"Expected a tinker_cookbook Renderer, got {type(renderer)}") + self.renderer = renderer + self.tokenizer = renderer.tokenizer + self.processor = None + + # Compute generation_prompt by decoding the generation suffix tokens + ctx = _make_render_context(idx=0, is_last=True) + suffix_tokens = self.renderer._get_generation_suffix("assistant", ctx) + self.generation_prompt = self.tokenizer.decode(suffix_tokens) if suffix_tokens else "" + + self.stop_sequences = self.renderer.get_stop_sequences() + + def _convert_message(self, msg: dict) -> dict: + """Convert an rllm message dict to a tinker Message dict.""" + tinker_msg = {"role": msg["role"]} + + content = msg.get("content", "") or "" + reasoning = (msg.get("reasoning", "") or "").strip() + + # Build structured content when reasoning or images are present + if reasoning: + parts = [] + parts.append({"type": "thinking", "thinking": reasoning}) + if content: + parts.append({"type": "text", "text": content}) + tinker_msg["content"] = parts + elif isinstance(msg.get("images"), list) and msg["images"]: + parts = [] + for img in msg["images"]: + parts.append({"type": "image", "image": img}) + if content: + # Strip leading tag if present (rllm convention) + if content.startswith(""): + content = content[len("") :] + parts.append({"type": "text", "text": content}) + tinker_msg["content"] = parts + else: + tinker_msg["content"] = content + + # Convert tool_calls to tinker ToolCall format + if msg.get("tool_calls"): + from tinker_cookbook.renderers.base import ToolCall as TinkerToolCall + + tool_calls = [] + for tc in msg["tool_calls"]: + if isinstance(tc, ToolCall): + # rllm ToolCall dataclass + args = tc.arguments if isinstance(tc.arguments, str) else json.dumps(tc.arguments) + tool_calls.append( + TinkerToolCall( + function=TinkerToolCall.FunctionBody(name=tc.name, arguments=args), + ) + ) + elif isinstance(tc, dict) and "function" in tc: + func = tc["function"] + args = func.get("arguments", "{}") + if not isinstance(args, str): + args = json.dumps(args) + tool_calls.append( + TinkerToolCall( + function=TinkerToolCall.FunctionBody(name=func["name"], arguments=args), + id=tc.get("id"), + ) + ) + elif isinstance(tc, dict) and "name" in tc: + args = tc.get("arguments", "{}") + if not isinstance(args, str): + args = json.dumps(args) + tool_calls.append( + TinkerToolCall( + function=TinkerToolCall.FunctionBody(name=tc["name"], arguments=args), + id=tc.get("id"), + ) + ) + if tool_calls: + tinker_msg["tool_calls"] = tool_calls + + # Handle tool response fields + if msg["role"] == "tool": + if "tool_call_id" in msg: + tinker_msg["tool_call_id"] = msg["tool_call_id"] + if "name" in msg: + tinker_msg["name"] = msg["name"] + + return tinker_msg + + def _convert_messages(self, messages: list[dict]) -> list[dict]: + """Convert a list of rllm message dicts to tinker Message format.""" + return [self._convert_message(m) for m in messages] + + def _convert_tools(self, tools: list[Tool | dict]) -> list[dict]: + """Convert rllm tools to tinker ToolSpec format.""" + tool_specs = [] + for tool in tools: + if isinstance(tool, Tool): + # rllm Tool object - extract from json property + tool_json = tool.json + if "function" in tool_json: + func = tool_json["function"] + tool_specs.append( + { + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ) + elif isinstance(tool, dict): + if "function" in tool: + func = tool["function"] + tool_specs.append( + { + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ) + elif "name" in tool: + tool_specs.append( + { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + } + ) + return tool_specs + + def _render_to_tokens(self, tinker_messages: list[dict], add_bos: bool = False, add_generation_prompt: bool = False) -> list[int]: + """Render tinker messages to a flat list of token IDs.""" + + chunks = [] + + if add_bos and self.renderer._bos_tokens: + chunks.append(tinker.EncodedTextChunk(tokens=self.renderer._bos_tokens)) + + last_user_idx = max( + (i for i, m in enumerate(tinker_messages) if m["role"] == "user"), + default=-1, + ) + + for idx, msg in enumerate(tinker_messages): + ctx = _make_render_context( + idx=idx, + is_last=(idx == len(tinker_messages) - 1) and not add_generation_prompt, + prev_message=tinker_messages[idx - 1] if idx > 0 else None, + last_user_index=last_user_idx, + ) + rendered = self.renderer.render_message(msg, ctx) + if rendered.header: + chunks.append(rendered.header) + chunks.extend(x for x in rendered.output if not isinstance(x, tinker.EncodedTextChunk) or x.tokens) + + if add_generation_prompt: + suffix_ctx = _make_render_context( + idx=len(tinker_messages), + is_last=True, + prev_message=tinker_messages[-1] if tinker_messages else None, + last_user_index=last_user_idx, + ) + suffix_tokens = self.renderer._get_generation_suffix("assistant", suffix_ctx) + if suffix_tokens: + chunks.append(tinker.EncodedTextChunk(tokens=suffix_tokens)) + + # Flatten chunks to token list + tokens = [] + for chunk in chunks: + if isinstance(chunk, tinker.EncodedTextChunk): + tokens.extend(chunk.tokens) + else: + # ImageChunk or other non-token chunk - use length as placeholder + # This path is for VL models; decode will produce placeholder tokens + tokens.extend([0] * chunk.length) + + return tokens + + def _prepare_messages(self, messages: list[dict], tools: list[Tool | dict] | None = None) -> list[dict]: + """Convert rllm messages to tinker format and prepend tool context if needed. + + Args: + messages: List of rllm message dicts. + tools: Optional list of tools to include in the system prompt. + + Returns: + List of tinker-format message dicts ready for rendering. + """ + tinker_messages = self._convert_messages(messages) + + if tools: + tool_specs = self._convert_tools(tools) + if tool_specs: + try: + system_prompt = "" + if tinker_messages and tinker_messages[0]["role"] == "system": + content = tinker_messages[0]["content"] + if isinstance(content, str): + system_prompt = content + tinker_messages = tinker_messages[1:] + prefix = self.renderer.create_conversation_prefix_with_tools(tool_specs, system_prompt) + tinker_messages = prefix + tinker_messages + except NotImplementedError: + logger.warning(f"Renderer {type(self.renderer).__name__} does not support tool calling. Tools will be ignored.") + + return tinker_messages + + def build_prompt(self, messages: list[dict], tools: list[Tool | dict] | None = None) -> ModelInput: + """Build a ModelInput prompt from messages, preserving image chunks for VLM. + + Unlike parse() which decodes to a string, this returns a ModelInput directly + via the renderer's build_generation_prompt, avoiding the token->string->token + round-trip and preserving ImageChunks for vision-language models. + + Args: + messages: List of rllm message dicts. + tools: Optional list of tools to include in the prompt. + + Returns: + tinker ModelInput with generation prompt appended. + """ + tinker_messages = self._prepare_messages(messages, tools=tools) + return self.renderer.build_generation_prompt(tinker_messages) + + def parse(self, messages: list[dict], add_generation_prompt: bool = False, is_first_msg: bool = False, tools: list[Tool | dict] | None = None, **kwargs) -> str: + """Parse messages into a prompt string. + + Note: For TinkerEngine, prefer build_prompt() which returns a ModelInput + directly and preserves image chunks. This method is for compatibility with + non-Tinker rollout engines. + + Args: + messages: List of rllm message dicts. + add_generation_prompt: Whether to append the generation prompt. + is_first_msg: Whether this is the first message (adds BOS token). + tools: Optional list of tools to include in the prompt. + + Returns: + The rendered prompt string. + """ + if not messages: + return "" + + tinker_messages = self._prepare_messages(messages, tools=tools) + + tokens = self._render_to_tokens(tinker_messages, add_bos=is_first_msg, add_generation_prompt=add_generation_prompt) + result = self.tokenizer.decode(tokens, skip_special_tokens=False) + + # Tinker puts the \n separator in the next message's header, so the last + # message lacks a trailing \n. HF templates always include it. Add it to + # match HF's apply_chat_template output. + if result and not result.endswith("\n"): + result += "\n" + + return result + + def parse_completion(self, completion_ids: list[int]) -> dict[str, str | list]: + """Parse completion token IDs into structured output. + + Args: + completion_ids: List of token IDs from model generation. + + Returns: + Dict with 'content', 'reasoning', and 'tool_calls' keys. + """ + parsed_msg, _success = self.renderer.parse_response(completion_ids) + + content = "" + reasoning = "" + tool_calls = [] + + msg_content = parsed_msg.get("content", "") + if isinstance(msg_content, str): + content = msg_content + elif isinstance(msg_content, list): + text_parts = [] + thinking_parts = [] + for part in msg_content: + if part["type"] == "text": + text_parts.append(part["text"]) + elif part["type"] == "thinking": + thinking_parts.append(part["thinking"]) + content = "".join(text_parts) + reasoning = "".join(thinking_parts) + + # Convert tinker ToolCall objects to rllm ToolCall dataclass + if parsed_msg.get("tool_calls"): + for tc in parsed_msg["tool_calls"]: + try: + args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + args = tc.function.arguments + tool_calls.append(ToolCall(name=tc.function.name, arguments=args)) + + return { + "content": content.strip(), + "reasoning": reasoning.strip(), + "tool_calls": tool_calls, + } + + def tokenize_and_mask(self, messages): + """Convert messages to token IDs with loss masks using tinker's supervised example builder. + + Returns: + Tuple of (prompt_ids, response_ids, response_mask) as torch tensors. + """ + tinker_messages = self._convert_messages(messages) + model_input, weights = self.renderer.build_supervised_example(tinker_messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_MESSAGE) + + all_tokens = model_input.to_ints() + weights_list = weights.tolist() + + # Split at first non-zero weight + boundary = next((i for i, w in enumerate(weights_list) if w > 0), len(weights_list)) + + prompt_ids = torch.tensor(all_tokens[:boundary], dtype=torch.long) + response_ids = torch.tensor(all_tokens[boundary:], dtype=torch.long) + response_mask = weights[boundary:].long() + + return prompt_ids, response_ids, response_mask + + def tokenize_and_mask_cumulative(self, messages): + """Convert multi-turn messages to token IDs with cumulative loss masks. + + Returns: + Tuple of (prompt_ids, response_ids, response_mask) as torch tensors. + """ + tinker_messages = self._convert_messages(messages) + model_input, weights = self.renderer.build_supervised_example(tinker_messages, train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES) + + all_tokens = model_input.to_ints() + weights_list = weights.tolist() + + # Split at first non-zero weight + boundary = next((i for i, w in enumerate(weights_list) if w > 0), len(weights_list)) + + prompt_ids = torch.tensor(all_tokens[:boundary], dtype=torch.long) + response_ids = torch.tensor(all_tokens[boundary:], dtype=torch.long) + response_mask = weights[boundary:].long() + + return prompt_ids, response_ids, response_mask + + def verify_equivalence(self, messages, verbose=True): + """Tinker renderers handle token-level correctness by design. + + NOTE(listar2000): the `verify_equivalence` test from parent does not make too much sense. + Instead of checking equivalence with HF templates, it check single versus multiple message parsing. + So it makes sense for the tinker parser to not pass this test. We simply return True here. + """ + return True diff --git a/rllm/parser/utils.py b/rllm/parser/utils.py index e255b04ba..61f52d40e 100644 --- a/rllm/parser/utils.py +++ b/rllm/parser/utils.py @@ -6,3 +6,14 @@ {"role": "user", "content": "What about Java?"}, {"role": "assistant", "content": "Let me search for Java information.", "tool_calls": [{"function": {"name": "search", "arguments": '{"query": "Java programming"}'}}]}, ] + +# Simple multi-turn messages for verify_equivalence tests. +# Ends with a user message (representing the prompt before model generation) +# to avoid HF template quirks like Qwen3's tag insertion on the last +# assistant message after the last user query. +SIMPLE_TEST_MESSAGES = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you! How can I help you today?"}, + {"role": "user", "content": "What is the capital of France?"}, +] diff --git a/rllm/trainer/config/tinker_rl_trainer.yaml b/rllm/trainer/config/tinker_rl_trainer.yaml index 95630a37c..8862068a6 100644 --- a/rllm/trainer/config/tinker_rl_trainer.yaml +++ b/rllm/trainer/config/tinker_rl_trainer.yaml @@ -69,7 +69,6 @@ rollout_engine: reasoning_effort: "medium" accumulate_reasoning: false disable_thinking: false - bypass_render_with_parser: false renderer_name: null # Override renderer name (null = auto-detect from model) # Data Configuration diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 51385bcd7..673132e41 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -23,9 +23,9 @@ from rllm.agents.agent import Episode from rllm.data import Dataset +from rllm.engine.rollout import RolloutEngine, TinkerEngine from rllm.experimental.common import AlgorithmConfig, simple_timer from rllm.experimental.protocol import BackendProtocol -from rllm.experimental.rollout import RolloutEngine, TinkerEngine from rllm.trainer.tinker.tinker_metrics_utils import ( print_metrics_table, update_training_metrics, @@ -113,6 +113,8 @@ def init_rollout_engine(self, **kwargs) -> RolloutEngine: Args: **kwargs: Additional arguments, including the various configurations + - strip_thinking_from_history: Whether to strip thinking from history (default = true) + - renderer_name: Name of the renderer to use (default = auto-detect from model) Returns: TinkerEngine: The initialized rollout engine. diff --git a/rllm/trainer/tinker/transform.py b/rllm/trainer/tinker/transform.py index f758b7ce1..b016535c3 100644 --- a/rllm/trainer/tinker/transform.py +++ b/rllm/trainer/tinker/transform.py @@ -11,9 +11,9 @@ from tinker_cookbook.supervised.common import create_rightshifted_model_input_and_leftshifted_targets from rllm.agents.agent import Trajectory, TrajectoryGroup +from rllm.engine.rollout.tinker_engine import _flat_token_input_length, _flat_token_input_to_model_input +from rllm.engine.rollout.types import TinkerTokenInput from rllm.experimental.common import AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups -from rllm.experimental.rollout.tinker_engine import _flat_token_input_length, _flat_token_input_to_model_input -from rllm.experimental.rollout.types import TinkerTokenInput def _is_prefix(seq1: TinkerTokenInput, seq2: TinkerTokenInput) -> bool: diff --git a/tests/parser/conftest.py b/tests/parser/conftest.py new file mode 100644 index 000000000..8ab875f00 --- /dev/null +++ b/tests/parser/conftest.py @@ -0,0 +1,54 @@ +"""Parser tests require real packages (transformers, pydantic, torch, etc.). + +The root conftest.py stubs out heavy optional dependencies for lightweight unit +tests. This conftest removes the specific stubs so parser integration tests can +use real packages. +""" + +import sys +import types + +# These are the exact modules stubbed by root conftest.py _STUB_MODULES list, +# plus the additional stubs it creates for sub-modules and fake classes. +_ROOT_STUB_MODULES = [ + "numpy", + "httpx", + "transformers", + "datasets", + "ray", + "pandas", + "polars", + "sympy", + "pylatexenc", + "antlr4", + "antlr4_python3_runtime", + "mcp", + "eval_protocol", + "hydra", + "fastapi", + "uvicorn", + "tqdm", + "yaml", + "pydantic", + "wrapt", + "asgiref", + "wandb", + "codetiming", + "click", + # Also stubbed explicitly by root conftest + "torch", + "PIL", + "openai", +] + +# Remove stub modules and any sub-modules created by root conftest +_to_remove = [] +for name in list(sys.modules): + base = name.split(".")[0] + if base in _ROOT_STUB_MODULES: + mod = sys.modules[name] + if isinstance(mod, types.ModuleType) and not hasattr(mod, "__file__"): + _to_remove.append(name) + +for name in _to_remove: + del sys.modules[name] diff --git a/tests/parser/test_chat_parser.py b/tests/parser/test_chat_parser.py index d45c7fdd8..4bac5428f 100644 --- a/tests/parser/test_chat_parser.py +++ b/tests/parser/test_chat_parser.py @@ -73,7 +73,7 @@ def test_parser_with_disable_thinking(): parser = QwenChatTemplateParser(tokenizer, disable_thinking=True) # Verify that thinking is disabled in the generation prompt - assert "\\n\\n\\n\\n" in parser.assistant_token + assert "\n\n\n\n" in parser.assistant_token # Test equivalence check assert parser.verify_equivalence(PARSER_TEST_MESSAGES) diff --git a/tests/parser/test_tinker_parser.py b/tests/parser/test_tinker_parser.py new file mode 100644 index 000000000..54dbedea0 --- /dev/null +++ b/tests/parser/test_tinker_parser.py @@ -0,0 +1,224 @@ +import sys +from unittest.mock import patch + +import pytest +from tinker_cookbook import renderers +from transformers import AutoTokenizer + +from rllm.parser import QwenChatTemplateParser +from rllm.parser.tinker_parser import TinkerChatTemplateParser +from rllm.parser.utils import SIMPLE_TEST_MESSAGES + + +@pytest.fixture +def qwen_tokenizer(): + return AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") + + +@pytest.fixture +def qwen_renderer(qwen_tokenizer): + return renderers.get_renderer("qwen3", qwen_tokenizer) + + +@pytest.fixture +def qwen_tinker_parser(qwen_renderer): + return TinkerChatTemplateParser(qwen_renderer) + + +def test_tinker_parser_init(qwen_tinker_parser): + """Verify that constructor sets up generation_prompt and stop_sequences.""" + assert qwen_tinker_parser.generation_prompt + assert isinstance(qwen_tinker_parser.generation_prompt, str) + assert qwen_tinker_parser.stop_sequences is not None + assert qwen_tinker_parser.tokenizer is not None + assert qwen_tinker_parser.renderer is not None + + +def test_tinker_parser_init_bad_renderer(): + """Verify TypeError when passing a non-renderer object.""" + with pytest.raises(TypeError, match="Expected a tinker_cookbook Renderer"): + TinkerChatTemplateParser("not a renderer") + + +def test_tinker_parser_parse(qwen_tinker_parser): + """Verify parse() returns a valid non-empty string.""" + result = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, add_generation_prompt=True, is_first_msg=True) + assert isinstance(result, str) + assert len(result) > 0 + + +def test_tinker_parser_parse_empty(): + """Verify parse([]) returns empty string.""" + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") + renderer = renderers.get_renderer("qwen3", tokenizer) + parser = TinkerChatTemplateParser(renderer) + assert parser.parse([]) == "" + + +def test_tinker_parser_parse_generation_prompt(qwen_tinker_parser): + """Verify that generation prompt is appended when requested.""" + with_prompt = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, add_generation_prompt=True, is_first_msg=True) + without_prompt = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, add_generation_prompt=False, is_first_msg=True) + # The version with generation prompt should be longer + assert len(with_prompt) > len(without_prompt) + + +def test_tinker_parser_parse_is_first_msg(qwen_tinker_parser): + """Verify is_first_msg controls BOS token inclusion.""" + with_bos = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, is_first_msg=True) + without_bos = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, is_first_msg=False) + # With BOS should be at least as long as without + assert len(with_bos) >= len(without_bos) + + +def test_tinker_parser_parse_with_reasoning(qwen_tinker_parser): + """Verify that reasoning is included when accumulate_reasoning=True.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there", "reasoning": "The user greeted me"}, + ] + with_reasoning = qwen_tinker_parser.parse(messages, accumulate_reasoning=True, is_first_msg=True) + without_reasoning = qwen_tinker_parser.parse(messages, accumulate_reasoning=False, is_first_msg=True) + assert "think" in with_reasoning or len(with_reasoning) > len(without_reasoning) + + +def test_tinker_parser_parse_completion(qwen_tinker_parser, qwen_tokenizer): + """Verify parse_completion returns correct structure.""" + # Encode a proper assistant response with thinking + end token. + # The renderer expects tokens as if produced by the model during generation, + # which means they must end with the stop sequence (<|im_end|> for Qwen3). + text = "\nLet me think about this.\n\n\nHello, how can I help?<|im_end|>" + token_ids = qwen_tokenizer.encode(text, add_special_tokens=False) + + result = qwen_tinker_parser.parse_completion(token_ids) + + assert isinstance(result, dict) + assert "content" in result + assert "reasoning" in result + assert "tool_calls" in result + assert isinstance(result["tool_calls"], list) + # The thinking should be extracted as reasoning + assert result["reasoning"] + assert "think" in result["reasoning"].lower() + assert "Hello" in result["content"] + + +def test_tinker_parser_tokenize_and_mask(qwen_tinker_parser): + """Verify tokenize_and_mask returns correct tensor shapes and mask values.""" + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ] + prompt_ids, response_ids, response_mask = qwen_tinker_parser.tokenize_and_mask(messages) + + assert prompt_ids.dim() == 1 + assert response_ids.dim() == 1 + assert response_mask.dim() == 1 + assert len(response_ids) == len(response_mask) + assert len(prompt_ids) > 0 + assert len(response_ids) > 0 + # Response mask should have non-zero values + assert response_mask.sum() > 0 + + +def test_tinker_parser_tokenize_and_mask_cumulative(qwen_tinker_parser): + """Verify tokenize_and_mask_cumulative returns correct tensor shapes.""" + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + {"role": "user", "content": "And 3+3?"}, + {"role": "assistant", "content": "6"}, + ] + prompt_ids, response_ids, response_mask = qwen_tinker_parser.tokenize_and_mask_cumulative(messages) + + assert prompt_ids.dim() == 1 + assert response_ids.dim() == 1 + assert response_mask.dim() == 1 + assert len(response_ids) == len(response_mask) + assert len(prompt_ids) > 0 + assert len(response_ids) > 0 + # Both assistant responses should be masked + assert response_mask.sum() > 0 + # Should have some zero-masked tokens (user message between assistants) + assert (response_mask == 0).any() + + +def test_tinker_parser_verify_equivalence(qwen_tinker_parser): + """Tinker parser should always return True for verify_equivalence.""" + assert qwen_tinker_parser.verify_equivalence(SIMPLE_TEST_MESSAGES) is True + + +def test_tinker_parser_matches_manual_qwen(qwen_tokenizer): + """Compare TinkerChatTemplateParser output with QwenChatTemplateParser for simple messages.""" + renderer = renderers.get_renderer("qwen3", qwen_tokenizer) + tinker_parser = TinkerChatTemplateParser(renderer) + manual_parser = QwenChatTemplateParser(qwen_tokenizer) + + # Simple messages without tool calls (avoid tool call format differences) + simple_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + tinker_result = tinker_parser.parse(simple_messages, add_generation_prompt=False, is_first_msg=True) + manual_result = manual_parser.parse(simple_messages, add_generation_prompt=False, is_first_msg=True) + + # Tokenize both and compare token sequences (more robust than string comparison + # because decode round-trip may differ in whitespace/special token rendering). + # Strip trailing whitespace since HF templates add \n after <|im_end|> but + # tinker's token-level rendering does not. + tinker_tokens = qwen_tokenizer.encode(tinker_result.rstrip(), add_special_tokens=False) + manual_tokens = qwen_tokenizer.encode(manual_result.rstrip(), add_special_tokens=False) + assert tinker_tokens == manual_tokens + + +def test_tinker_parser_message_conversion(qwen_tinker_parser): + """Test that message conversion handles various message formats.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + { + "role": "assistant", + "content": "Let me search.", + "tool_calls": [{"function": {"name": "search", "arguments": '{"q": "test"}'}}], + }, + ] + converted = qwen_tinker_parser._convert_messages(messages) + assert len(converted) == 3 + assert converted[0]["role"] == "system" + assert converted[1]["role"] == "user" + assert converted[2]["role"] == "assistant" + + +def test_import_error_without_tinker(): + """Verify helpful ImportError when tinker-cookbook is not installed.""" + # The module-level import in tinker_parser.py raises ImportError if tinker-cookbook + # is not installed. Since the module is already imported, we verify the error message + # by checking the module-level try/except pattern exists. + import importlib + + saved_modules = {} + modules_to_remove = [key for key in sys.modules if key.startswith(("tinker_cookbook", "tinker"))] + # Also remove the cached tinker_parser module so it can be re-imported + if "rllm.parser.tinker_parser" in sys.modules: + saved_modules["rllm.parser.tinker_parser"] = sys.modules.pop("rllm.parser.tinker_parser") + for key in modules_to_remove: + saved_modules[key] = sys.modules.pop(key) + + try: + with patch.dict( + sys.modules, + { + "tinker_cookbook": None, + "tinker_cookbook.renderers": None, + "tinker_cookbook.renderers.base": None, + "tinker": None, + }, + ): + with pytest.raises(ImportError, match="tinker-cookbook and tinker are required"): + importlib.import_module("rllm.parser.tinker_parser") + finally: + sys.modules.update(saved_modules) From 9a9cb761a3d77302ebd10e4719ea16d28465e31d Mon Sep 17 00:00:00 2001 From: listar2000 <35262801+listar2000@users.noreply.github.com> Date: Tue, 10 Mar 2026 17:08:31 -0500 Subject: [PATCH 05/21] dump changes to rollout_engine into main file --- rllm/engine/rollout/rollout_engine.py | 44 ++++++- rllm/experimental/rollout/rollout_engine.py | 127 -------------------- 2 files changed, 43 insertions(+), 128 deletions(-) delete mode 100644 rllm/experimental/rollout/rollout_engine.py diff --git a/rllm/engine/rollout/rollout_engine.py b/rllm/engine/rollout/rollout_engine.py index 74ccd8b73..d5dd8b321 100644 --- a/rllm/engine/rollout/rollout_engine.py +++ b/rllm/engine/rollout/rollout_engine.py @@ -1,3 +1,4 @@ +import asyncio from dataclasses import dataclass from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput @@ -16,9 +17,11 @@ class ModelOutput: multi_modal_inputs: dict[str, list] | None = None logprobs: list[float] | None = None # completion logprobs prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) prompt_length: int = 0 completion_length: int = 0 finish_reason: str | None = None + weight_version: int | None = None # policy version at time of generation def to_dict(self): return { @@ -34,6 +37,7 @@ def to_dict(self): "prompt_length": self.prompt_length, "completion_length": self.completion_length, "finish_reason": self.finish_reason, + "weight_version": self.weight_version, } @classmethod @@ -51,6 +55,7 @@ def from_dict(cls, data: dict): prompt_length=data.get("prompt_length", 0), completion_length=data.get("completion_length", 0), finish_reason=data.get("finish_reason"), + weight_version=data.get("weight_version"), ) @@ -60,7 +65,44 @@ class RolloutEngine: is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks def __init__(self, *args, **kwargs): - pass + # Gate mechanism for pausing model calls during weight sync + self._gate: asyncio.Event = asyncio.Event() + self._gate.set() # open by default + self._active_calls: int = 0 + self._drained_event: asyncio.Event = asyncio.Event() + self._drained_event.set() # initially drained (no active calls) + self.weight_version: int = 0 + + # --- Gate mechanism --- + + def close_gate(self) -> None: + """Close the gate. New model calls will block at wait_for_gate().""" + self._gate.clear() + + def open_gate(self) -> None: + """Open the gate, releasing any blocked model calls.""" + self._gate.set() + + def on_model_call_complete(self) -> None: + """Unregister active call. Engines must call this at the END of + get_model_response() (in a finally block).""" + self._active_calls -= 1 + if self._active_calls <= 0: + self._active_calls = 0 + self._drained_event.set() + + async def wait_for_gate(self) -> None: + """Wait until gate is open, then register as active call. + Engines must call this at the START of get_model_response().""" + await self._gate.wait() + self._active_calls += 1 + self._drained_event.clear() + + async def wait_for_drain(self) -> None: + """Wait until all active model calls complete. Used during weight sync.""" + await self._drained_event.wait() + + # --- Model response --- async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: raise NotImplementedError("get_model_response is not implemented") diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py deleted file mode 100644 index 29ada4db9..000000000 --- a/rllm/experimental/rollout/rollout_engine.py +++ /dev/null @@ -1,127 +0,0 @@ -import asyncio -from dataclasses import dataclass - -from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput -from rllm.parser import ChatTemplateParser -from rllm.tools.tool_base import ToolCall - - -@dataclass -class ModelOutput: - text: str | None = None - content: str | None = None - reasoning: str | None = None - tool_calls: list[ToolCall] | None = None - prompt_ids: TokenInput | None = None - completion_ids: list[int] | None = None - multi_modal_inputs: dict[str, list] | None = None - logprobs: list[float] | None = None # completion logprobs - prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids - routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) - prompt_length: int = 0 - completion_length: int = 0 - finish_reason: str | None = None - weight_version: int | None = None # policy version at time of generation - - def to_dict(self): - return { - "text": self.text, - "content": self.content, - "reasoning": self.reasoning, - "tool_calls": [tool_call.to_dict() for tool_call in self.tool_calls] if self.tool_calls else [], - "prompt_ids": self.prompt_ids, - "completion_ids": self.completion_ids, - "multi_modal_inputs": self.multi_modal_inputs, - "logprobs": self.logprobs, - "prompt_logprobs": self.prompt_logprobs, - "prompt_length": self.prompt_length, - "completion_length": self.completion_length, - "finish_reason": self.finish_reason, - "weight_version": self.weight_version, - } - - @classmethod - def from_dict(cls, data: dict): - return cls( - text=data.get("text"), - content=data.get("content"), - reasoning=data.get("reasoning"), - tool_calls=[ToolCall(**tool_call) for tool_call in data.get("tool_calls", [])] if data.get("tool_calls") else None, - prompt_ids=data.get("prompt_ids"), - completion_ids=data.get("completion_ids"), - multi_modal_inputs=data.get("multi_modal_inputs"), - logprobs=data.get("logprobs"), - prompt_logprobs=data.get("prompt_logprobs"), - prompt_length=data.get("prompt_length", 0), - completion_length=data.get("completion_length", 0), - finish_reason=data.get("finish_reason"), - weight_version=data.get("weight_version"), - ) - - -class RolloutEngine: - chat_parser: ChatTemplateParser | None = None - tokenizer: Tokenizer | None = None - is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks - - def __init__(self, *args, **kwargs): - # Gate mechanism for pausing model calls during weight sync - self._gate: asyncio.Event = asyncio.Event() - self._gate.set() # open by default - self._active_calls: int = 0 - self._drained_event: asyncio.Event = asyncio.Event() - self._drained_event.set() # initially drained (no active calls) - self.weight_version: int = 0 - - # --- Gate mechanism --- - - def close_gate(self) -> None: - """Close the gate. New model calls will block at wait_for_gate().""" - self._gate.clear() - - def open_gate(self) -> None: - """Open the gate, releasing any blocked model calls.""" - self._gate.set() - - async def wait_for_gate(self) -> None: - """Wait until gate is open, then register as active call. - Engines must call this at the START of get_model_response().""" - await self._gate.wait() - self._active_calls += 1 - self._drained_event.clear() - - def on_model_call_complete(self) -> None: - """Unregister active call. Engines must call this at the END of - get_model_response() (in a finally block).""" - self._active_calls -= 1 - if self._active_calls <= 0: - self._active_calls = 0 - self._drained_event.set() - - async def wait_for_drain(self) -> None: - """Wait until all active model calls complete. Used during weight sync.""" - await self._drained_event.wait() - - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - raise NotImplementedError("get_model_response is not implemented") - - def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: - """ - Assemble model output from a token output. - """ - raise NotImplementedError("assemble_model_output is not implemented") - - async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TokenOutput: - """Obtain the token output from the given token input.""" - raise NotImplementedError("get_token_output_from_token_input is not implemented") - - async def wake_up(self): - pass - - async def sleep(self): - pass - - @property - def supports_token_in_token_out(self) -> bool: - """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" - return False From 18ca0f42edab3e5b797403d80dfa4871c4f7a607 Mon Sep 17 00:00:00 2001 From: listar2000 <35262801+listar2000@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:13:27 -0500 Subject: [PATCH 06/21] refactor base rollout engine class to standardize gating behaviors --- rllm/engine/rollout/openai_engine.py | 2 +- rllm/engine/rollout/rollout_engine.py | 14 ++++--- rllm/engine/rollout/tinker_engine.py | 2 +- rllm/engine/rollout/verl_engine.py | 2 +- rllm/experimental/unified_trainer.py | 59 +++++++++++---------------- 5 files changed, 36 insertions(+), 43 deletions(-) diff --git a/rllm/engine/rollout/openai_engine.py b/rllm/engine/rollout/openai_engine.py index 60c130505..78d683f4e 100644 --- a/rllm/engine/rollout/openai_engine.py +++ b/rllm/engine/rollout/openai_engine.py @@ -225,7 +225,7 @@ async def completion(self, prompt: str | list[int], **kwargs) -> ModelOutput: print(f"Error: {e}, retrying...") await asyncio.sleep(1) - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: if self._use_chat_completions: accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) if accumulate_reasoning: diff --git a/rllm/engine/rollout/rollout_engine.py b/rllm/engine/rollout/rollout_engine.py index d5dd8b321..755239968 100644 --- a/rllm/engine/rollout/rollout_engine.py +++ b/rllm/engine/rollout/rollout_engine.py @@ -84,16 +84,14 @@ def open_gate(self) -> None: self._gate.set() def on_model_call_complete(self) -> None: - """Unregister active call. Engines must call this at the END of - get_model_response() (in a finally block).""" + """Unregister active call. Engines will call this at the END of get_model_response().""" self._active_calls -= 1 if self._active_calls <= 0: self._active_calls = 0 self._drained_event.set() async def wait_for_gate(self) -> None: - """Wait until gate is open, then register as active call. - Engines must call this at the START of get_model_response().""" + """Wait until gate is open, then register as active call. Engines will call this at the START of get_model_response().""" await self._gate.wait() self._active_calls += 1 self._drained_event.clear() @@ -103,9 +101,15 @@ async def wait_for_drain(self) -> None: await self._drained_event.wait() # --- Model response --- + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + raise NotImplementedError(f"_get_model_response is not implemented for {self.__class__.__name__}") async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - raise NotImplementedError("get_model_response is not implemented") + await self.wait_for_gate() + try: + return await self._get_model_response(messages, **kwargs) + finally: + self.on_model_call_complete() def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: """ diff --git a/rllm/engine/rollout/tinker_engine.py b/rllm/engine/rollout/tinker_engine.py index 12de041e2..f70cbec39 100644 --- a/rllm/engine/rollout/tinker_engine.py +++ b/rllm/engine/rollout/tinker_engine.py @@ -236,7 +236,7 @@ def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutp ) @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: """ Generate model response for a given set of messages. diff --git a/rllm/engine/rollout/verl_engine.py b/rllm/engine/rollout/verl_engine.py index 98ddc5b13..69b125de0 100644 --- a/rllm/engine/rollout/verl_engine.py +++ b/rllm/engine/rollout/verl_engine.py @@ -71,7 +71,7 @@ async def get_token_output_from_token_input(self, token_input: TokenInput, **kwa return token_output @override - async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: # these go to the parser tools = kwargs.pop("tools", []) accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 5306101af..71e298dd1 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -801,41 +801,30 @@ def __init__( backend: Literal["verl", "tinker", "fireworks"] = "verl", **kwargs, ): - if backend == "verl": - from rllm.experimental.verl.verl_launcher import VerlTrainerLauncher - - self.launcher = VerlTrainerLauncher( - config=config, - workflow_class=workflow_class, - train_dataset=train_dataset, - val_dataset=val_dataset, - workflow_args=workflow_args, - **kwargs, - ) - elif backend == "tinker": - from rllm.trainer.tinker.tinker_launcher import TinkerTrainerLauncher - - self.launcher = TinkerTrainerLauncher( - config=config, - workflow_class=workflow_class, - train_dataset=train_dataset, - val_dataset=val_dataset, - workflow_args=workflow_args, - **kwargs, - ) - elif backend == "fireworks": - from rllm.trainer.fireworks.fireworks_launcher import ( - FireworksTrainerLauncher, - ) - - self.launcher = FireworksTrainerLauncher( - config=config, - workflow_class=workflow_class, - train_dataset=train_dataset, - val_dataset=val_dataset, - workflow_args=workflow_args, - **kwargs, - ) + match backend: + case "verl": + from rllm.experimental.verl.verl_launcher import VerlTrainerLauncher + + launcher_cls = VerlTrainerLauncher + case "tinker": + from rllm.trainer.tinker.tinker_launcher import TinkerTrainerLauncher + + launcher_cls = TinkerTrainerLauncher + case "fireworks": + from rllm.trainer.fireworks.fireworks_launcher import FireworksTrainerLauncher + + launcher_cls = FireworksTrainerLauncher + case _: + raise ValueError(f"Unsupported backend: {backend}, must be one of ['verl', 'tinker', 'fireworks']") + + self.launcher = launcher_cls( + config=config, + workflow_class=workflow_class, + train_dataset=train_dataset, + val_dataset=val_dataset, + workflow_args=workflow_args, + **kwargs, + ) def train(self): self.launcher.train() From 764d0e10f232dd9dd1c6b7136f0140b7ea7de452 Mon Sep 17 00:00:00 2001 From: listar2000 <35262801+listar2000@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:54:46 -0500 Subject: [PATCH 07/21] make tinker backend fully compatible --- rllm/agents/agent.py | 37 +++++++++++++++---- rllm/engine/rollout/rollout_engine.py | 4 +- .../engine/unified_workflow_engine.py | 3 +- rllm/experimental/unified_trainer.py | 2 +- rllm/trainer/tinker/tinker_backend.py | 6 +++ 5 files changed, 42 insertions(+), 10 deletions(-) diff --git a/rllm/agents/agent.py b/rllm/agents/agent.py index 14664a6b5..1b2ad9c07 100644 --- a/rllm/agents/agent.py +++ b/rllm/agents/agent.py @@ -2,6 +2,7 @@ import uuid from abc import ABC, abstractmethod +from copy import deepcopy from dataclasses import dataclass, field from functools import cached_property from typing import TYPE_CHECKING, Any @@ -41,6 +42,7 @@ class Step: prompt_ids: list[int] | list[Any] = field(default_factory=list) response_ids: list[int] = field(default_factory=list) logprobs: list[float] = field(default_factory=list) + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) chat_completions: list[dict[str, str]] = field(default_factory=list) @@ -60,7 +62,13 @@ class Step: # TODO: potentially rename this as "advantages" so its clearer that it allows a generic list. advantage: list[float] | float | None = None + # weight version at time of generation (for async training staleness tracking) + weight_version: int | None = None + def __post_init__(self): + self.chat_completions = deepcopy(self.chat_completions) + self.info = deepcopy(self.info) + if self.model_output is None: return # backfill fields like prompt_ids, response_ids, logprobs, etc. @@ -70,22 +78,34 @@ def __post_init__(self): self.response_ids = self.model_output.completion_ids if len(self.logprobs) == 0 and self.model_output.logprobs is not None: self.logprobs = self.model_output.logprobs - - # check that the token ids are filled - # TODO(listar2000): this might cause compatibility issue. Double check if we should make these assertions. - # assert len(self.prompt_ids) > 0, "prompt_ids is empty" - # assert len(self.response_ids) > 0, "response_ids is empty" + if self.routing_matrices is None and getattr(self.model_output, "routing_matrices", None) is not None: + self.routing_matrices = self.model_output.routing_matrices + if self.weight_version is None and hasattr(self.model_output, "weight_version"): + self.weight_version = self.model_output.weight_version # check that the lengths would match up if len(self.logprobs) > 0: assert len(self.response_ids) == len(self.logprobs), f"length mismatch between response_ids and logprobs, got {len(self.response_ids)}, {len(self.logprobs)}" def to_dict(self) -> dict: + from rllm.tools.tool_base import ToolCall, ToolOutput + + # Helper function to recursively convert ToolCall and ToolOutput objects to dicts + def _serialize_value(value): + if isinstance(value, ToolCall | ToolOutput): + return value.to_dict() + elif isinstance(value, list): + return [_serialize_value(item) for item in value] + elif isinstance(value, dict): + return {k: _serialize_value(v) for k, v in value.items()} + else: + return value + return { "prompt_ids": self.prompt_ids, "response_ids": self.response_ids, "logprobs": self.logprobs, - "chat_completions": self.chat_completions, + "chat_completions": _serialize_value(self.chat_completions), "observation": self.observation, "thought": self.thought, "action": self.action.action if isinstance(self.action, Action) else self.action, @@ -96,6 +116,7 @@ def to_dict(self) -> dict: "done": self.done, "mc_return": self.mc_return, "advantage": self.advantage, + "weight_version": self.weight_version, } @classmethod @@ -116,7 +137,8 @@ def from_dict(cls, data: dict) -> Step: reward=data["reward"], done=data["done"], mc_return=data["mc_return"], - advantage=data["advantage"], + advantage=data.get("advantage", 0.0), + weight_version=data.get("weight_version"), ) @classmethod @@ -125,6 +147,7 @@ def from_model_output(cls, model_output: ModelOutput, messages: list[dict] | Non prompt_ids=model_output.prompt_ids or [], response_ids=model_output.completion_ids or [], logprobs=model_output.logprobs or [], + routing_matrices=getattr(model_output, "routing_matrices", None), chat_completions=(messages or []) + [{"role": "assistant", "content": model_output.content, "reasoning": model_output.reasoning}], thought=model_output.reasoning or "", action=action, diff --git a/rllm/engine/rollout/rollout_engine.py b/rllm/engine/rollout/rollout_engine.py index 755239968..c7cf14ebf 100644 --- a/rllm/engine/rollout/rollout_engine.py +++ b/rllm/engine/rollout/rollout_engine.py @@ -107,7 +107,9 @@ async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutp async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: await self.wait_for_gate() try: - return await self._get_model_response(messages, **kwargs) + result = await self._get_model_response(messages, **kwargs) + result.weight_version = self.weight_version + return result finally: self.on_model_call_complete() diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index 5b3ceab32..5084f5144 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -16,6 +16,7 @@ # Avoid hard dependency on verl at import time; only for typing if TYPE_CHECKING: + from omegaconf import DictConfig from verl import DataProto from rllm.utils.episode_logger import EpisodeLogger @@ -29,7 +30,7 @@ def __init__( workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, - config=None, + config: DictConfig | None = None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 71e298dd1..4f24e2be4 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -382,7 +382,7 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts # ========================================================================= - # Concurrent (async) training methods + # Fully-asynchronous training pipeline # ========================================================================= async def _fit_fully_async(self, trainer_state: TrainerState) -> None: diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index a4652af5c..e0598b9dc 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -399,6 +399,9 @@ async def on_train_start(self, trainer_state: TrainerState) -> None: # Initialize training client and load checkpoint start_batch, self.sampling_client = await self.policy_trainer.initialize_async(resume_from_checkpoint=True) + # Propagate sampling_client to rollout engine so it can make inference calls + self.rollout_engine.set_sampling_client(self.sampling_client) + # Update trainer state with the start batch from checkpoint trainer_state.global_step = start_batch @@ -421,6 +424,9 @@ async def on_policy_updated(self, trainer_state: TrainerState) -> None: do_save = save_freq > 0 and global_step % save_freq == 0 self.sampling_client = await self.policy_trainer.save_checkpoint_and_get_sampling_client(global_step, kind="both", do_save=do_save) + # Propagate updated sampling_client to rollout engine for async weight sync + self.rollout_engine.set_sampling_client(self.sampling_client) + async def on_batch_end(self, trainer_state: TrainerState) -> None: """Called at the end of each batch. From 1da008503d2afd58dd60fa965ccd644b6c7aac7b Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Mon, 30 Mar 2026 19:45:41 -0700 Subject: [PATCH 08/21] merge Kyle's fork --- pyproject.toml | 2 +- rllm/agents/agent.py | 1 + rllm/engine/rollout/openai_engine.py | 68 ++++- rllm/engine/rollout/verl_engine.py | 3 +- rllm/experimental/buffer.py | 277 ++++++++++++++++++ rllm/experimental/common/config.py | 28 +- rllm/experimental/config/rllm/base.yaml | 22 +- .../engine/unified_workflow_engine.py | 45 --- rllm/experimental/episode_buffer.py | 147 ---------- rllm/experimental/metrics.py | 122 ++++++++ rllm/experimental/protocol.py | 24 -- rllm/experimental/sync_coordinator.py | 10 +- rllm/experimental/unified_trainer.py | 249 +++++++--------- rllm/parser/chat_template_parser.py | 2 +- rllm/tools/tool_base.py | 31 ++ rllm/trainer/tinker/tinker_backend.py | 25 +- rllm/trainer/tinker/transform.py | 46 ++- rllm/workflows/workflow.py | 4 +- 18 files changed, 685 insertions(+), 421 deletions(-) create mode 100644 rllm/experimental/buffer.py delete mode 100644 rllm/experimental/episode_buffer.py create mode 100644 rllm/experimental/metrics.py diff --git a/pyproject.toml b/pyproject.toml index 666a945a5..2bdcc76e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,7 +178,7 @@ where = ["."] include = ["rllm", "rllm.*"] [tool.setuptools.package-data] -rllm = ["trainer/config/*.yaml"] +rllm = ["trainer/config/*.yaml", "experimental/config/**/*.yaml"] [tool.mypy] plugins = ["pydantic.mypy"] diff --git a/rllm/agents/agent.py b/rllm/agents/agent.py index 1b2ad9c07..393039e04 100644 --- a/rllm/agents/agent.py +++ b/rllm/agents/agent.py @@ -319,6 +319,7 @@ class TrajectoryGroup: trajectories: list[Trajectory] group_id: str = "" metadata: list[dict] = field(default_factory=list) + weight_version: int = 0 @cached_property def group_role(self) -> str: diff --git a/rllm/engine/rollout/openai_engine.py b/rllm/engine/rollout/openai_engine.py index 78d683f4e..ec95588b3 100644 --- a/rllm/engine/rollout/openai_engine.py +++ b/rllm/engine/rollout/openai_engine.py @@ -3,6 +3,7 @@ import logging import os from io import BytesIO +import json import openai from PIL import Image @@ -10,7 +11,7 @@ from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine from rllm.globals import THOUGHT_DELIMITER_END, THOUGHT_DELIMITER_START from rllm.parser import ChatTemplateParser -from rllm.tools.tool_base import Tool +from rllm.tools.tool_base import Tool, ToolCall, ToolOutput from rllm.workflows import TerminationEvent, TerminationReason @@ -80,6 +81,57 @@ def _prepare_max_tokens_param(self, sampling_params: dict, prompt_length: int = return {"max_tokens": max_tokens} + def _convert_openai_to_tool_calls(self, tool_calls: list[dict] | None) -> list[ToolCall]: + """Convert OpenAI tool calls to internal ToolCall objects.""" + if not tool_calls: + return [] + processed_tool_calls: list[ToolCall] = [] + for tool_call in tool_calls: + try: + arguments = json.loads(tool_call.function.arguments) + except Exception as e: + print(f"Error parsing tool call: {tool_call.function.arguments}, error: {e}") + continue + processed_tool_calls.append( + ToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=arguments, + ) + ) + return processed_tool_calls + + def _convert_tool_calls_to_openai(self, tool_calls: list[ToolCall] | None) -> list[dict] | None: + """Convert internal ToolCall objects to OpenAI format using base class method.""" + if not tool_calls: + return None + return [tool_call.to_openai_format() if isinstance(tool_call, ToolCall) else tool_call for tool_call in tool_calls] + + def _convert_tool_outputs_to_openai(self, tool_outputs: list[ToolOutput] | None) -> list[dict] | None: + """Convert internal ToolOutput objects to OpenAI format using base class method.""" + if not tool_outputs: + return None + return [tool_output.to_openai_format() if isinstance(tool_output, ToolOutput) else tool_output for tool_output in tool_outputs] + + def _prepare_messages_for_openai(self, messages: list[dict]) -> list[dict]: + """Convert messages from internal format to OpenAI format.""" + openai_messages = [] + for msg in messages: + role = msg.get("role") + if role == "assistant": + openai_msg = {"role": "assistant", "content": msg.get("content")} + if "tool_calls" in msg and msg["tool_calls"]: + openai_msg["tool_calls"] = self._convert_tool_calls_to_openai(msg["tool_calls"]) + openai_messages.append(openai_msg) + elif role == "tool": + assert "tool_outputs" in msg, "Tool message must contain tool_outputs" + tool_msgs = self._convert_tool_outputs_to_openai(msg["tool_outputs"]) + if tool_msgs: + openai_messages.extend(tool_msgs) + else: + openai_messages.append(msg) + return openai_messages + async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: kwargs.pop("application_id", None) kwargs.pop("validate", None) @@ -90,16 +142,22 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: sampling_params.update(kwargs) create_params = self._prepare_max_tokens_param(sampling_params) - converted_messages = self._convert_messages_to_openai_format(messages) + sampling_params.update(create_params) + + tools = sampling_params.pop("tools", self.tools) + if tools: + tools = [tool.json if isinstance(tool, Tool) else tool for tool in tools] + + # Convert messages from to OpenAI format + openai_messages = self._prepare_messages_for_openai(messages) retries = self.api_retries while retries > 0: try: - response = await self.client.chat.completions.create(model=self.model, messages=converted_messages, timeout=3600, **create_params, **sampling_params) - + response = await self.client.chat.completions.create(model=self.model, messages=openai_messages, tools=tools, timeout=3600, **sampling_params) content = response.choices[0].message.content reasoning = response.choices[0].message.reasoning if hasattr(response.choices[0].message, "reasoning") and isinstance(response.choices[0].message.reasoning, str) else "" - tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") and isinstance(response.choices[0].message.tool_calls, list) else [] + tool_calls = self._convert_openai_to_tool_calls(response.choices[0].message.tool_calls) # Build text with reasoning if available, otherwise use content if reasoning: diff --git a/rllm/engine/rollout/verl_engine.py b/rllm/engine/rollout/verl_engine.py index 69b125de0..969f73535 100644 --- a/rllm/engine/rollout/verl_engine.py +++ b/rllm/engine/rollout/verl_engine.py @@ -75,8 +75,9 @@ async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutp # these go to the parser tools = kwargs.pop("tools", []) accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) + reasoning_effort = kwargs.pop("reasoning_effort", "medium") - prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning) + prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning, reasoning_effort=reasoning_effort) request_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) # list[int] if any(msg.get("images", None) is not None and msg["role"] == "user" for msg in messages) and self.processor is not None: diff --git a/rllm/experimental/buffer.py b/rllm/experimental/buffer.py new file mode 100644 index 000000000..a88794d04 --- /dev/null +++ b/rllm/experimental/buffer.py @@ -0,0 +1,277 @@ +"""TrajectoryGroupBuffer for async training. + +Accumulates episodes, processes into ready-to-train trajectory groups, +with optional NVMe offloading for memory management. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import pickle +import tempfile +from dataclasses import dataclass, field + +from rllm.agents.agent import Episode, TrajectoryGroup +from rllm.experimental.common import ( + AlgorithmConfig, + CompactFilteringConfig, + RejectionSamplingConfig, + TransformConfig, + collect_reward_and_advantage_from_trajectory_groups, +) +from rllm.experimental.common.transform import transform_episodes_to_trajectory_groups +from rllm.experimental.metrics import MetricsAggregator +from rllm.experimental.sync_coordinator import SyncCoordinator +from rllm.workflows.workflow import TerminationReason + +logger = logging.getLogger(__name__) + + +_EPISODE_STRIP_KEYS = {"prompt_ids", "response_ids", "logprobs", "model_output", "routing_matrices"} +_EPISODE_STRIP_LIST_DEFAULTS = {"prompt_ids", "response_ids", "logprobs"} + + +@dataclass +class TaskBatch: + """All trajectory groups produced from one task's episodes, plus stripped episodes for UI logging.""" + groups: list[TrajectoryGroup] + episodes: list[Episode] = field(default_factory=list) + + +class TrajectoryGroupBuffer: + """Accumulates episodes, processes into trajectory groups, yields to training. + + When all rollouts for a task arrive: + 1. Record episode-level metrics to aggregator (before any filtering) + 2. Transform episodes -> trajectory groups + 3. Compact filtering + drop groups with < min_trajs_per_group + 4. Compute advantages + 5. If rejection sampling enabled: drop groups with all-zero advantage + 6. Queue the task batch for training + + Filtered groups are reported directly to the coordinator (which tracks + throttle slots and filter counts). Only non-empty task batches are queued. + All metrics flow through the shared MetricsAggregator. + + Optionally offloads pending episodes and/or queued task batches to + disk to reduce memory pressure (disabled by default). + """ + + def __init__( + self, + group_size: int, + coordinator: SyncCoordinator, + aggregator: MetricsAggregator, + algorithm_config: AlgorithmConfig, + transform_config: TransformConfig, + cf_config: CompactFilteringConfig, + rs_config: RejectionSamplingConfig, + episode_offload_dir: str | None = None, + trajectory_group_offload_dir: str | None = None, + ): + self._group_size = group_size + self._coordinator = coordinator + self._aggregator = aggregator + self._algorithm_config = algorithm_config + self._transform_config = transform_config + self._cf_config = cf_config + self._rs_config = rs_config + + # Episode offloading: pending episodes serialized to disk + self._episode_offload_dir = episode_offload_dir + if episode_offload_dir: + os.makedirs(episode_offload_dir, exist_ok=True) + self._pending: dict[str, list[Episode | str]] = {} # str = offloaded file path + + # Trajectory group offloading: queued task batches serialized to disk + self._tg_offload_dir = trajectory_group_offload_dir + if trajectory_group_offload_dir: + os.makedirs(trajectory_group_offload_dir, exist_ok=True) + self._queue: asyncio.Queue[TaskBatch | str | None] = asyncio.Queue() + + async def _offload_episode(self, task_id: str, episode: Episode) -> str: + """Serialize episode to disk, return file path.""" + idx = len(self._pending.get(task_id, [])) + path = os.path.join(self._episode_offload_dir, f"{task_id}_{idx}.pkl") + await asyncio.to_thread(self._pickle_dump, path, episode) + return path + + async def _load_pending_episodes(self, task_id: str) -> list[Episode]: + """Load all pending episodes for a task, deserializing offloaded ones.""" + episodes = [] + for item in self._pending.pop(task_id, []): + if isinstance(item, str): + ep = await asyncio.to_thread(self._pickle_load, item) + episodes.append(ep) + else: + episodes.append(item) + return episodes + + async def _offload_task_batch(self, batch: TaskBatch) -> str: + """Serialize task batch to disk, return file path.""" + fd, path = tempfile.mkstemp(dir=self._tg_offload_dir, suffix=".pkl") + os.close(fd) + await asyncio.to_thread(self._pickle_dump, path, batch) + return path + + async def _load_task_batch(self, item: TaskBatch | str) -> TaskBatch: + """Load task batch, deserializing if offloaded.""" + if isinstance(item, str): + return await asyncio.to_thread(self._pickle_load, item) + return item + + @staticmethod + def _pickle_dump(path: str, obj) -> None: + with open(path, "wb") as f: + pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) + + @staticmethod + def _pickle_load(path: str): + with open(path, "rb") as f: + obj = pickle.load(f) + os.remove(path) + return obj + + async def add_episode(self, task_id: str, episode: Episode) -> bool: + """Add episode. When group completes, process and queue task batch.""" + # Offload episode to disk if enabled + if self._episode_offload_dir: + path = await self._offload_episode(task_id, episode) + self._pending.setdefault(task_id, []).append(path) + else: + self._pending.setdefault(task_id, []).append(episode) + + if len(self._pending[task_id]) < self._group_size: + return False + + # Group complete — load all episodes + if self._episode_offload_dir: + episodes = await self._load_pending_episodes(task_id) + else: + episodes = self._pending.pop(task_id, []) + + weight_version = self._min_weight_version(episodes) + + # 1. Record episode-level metrics (includes filtered tasks) + self._record_episode_metrics(episodes) + + # 2. Transform episodes -> trajectory groups + traj_groups, transform_metrics = transform_episodes_to_trajectory_groups( + episodes, self._transform_config, self._cf_config, + ) + # Strip heavy fields from episodes for UI logging, free bulk memory + for ep in episodes: + for traj in ep.trajectories: + for step in traj.steps: + for key in _EPISODE_STRIP_KEYS: + setattr(step, key, [] if key in _EPISODE_STRIP_LIST_DEFAULTS else None) + self._aggregator.record_dict(transform_metrics) + + # 3. Drop groups with too few trajectories + before_min_traj = len(traj_groups) + traj_groups = [g for g in traj_groups if len(g.trajectories) >= self._rs_config.min_trajs_per_group] + self._aggregator.record("buffer/filtered_min_trajs", before_min_traj - len(traj_groups)) + + if not traj_groups: + self._coordinator.on_group_filtered() + return True + + # 4. Compute advantages + adv_metrics = collect_reward_and_advantage_from_trajectory_groups( + traj_groups, self._algorithm_config, + ) + self._aggregator.record_dict(adv_metrics) + + # 5. Rejection sampling: drop groups with all-zero advantage + filtered_zero_adv = 0 + if self._rs_config.filter_uniform_groups: + before_adv = len(traj_groups) + traj_groups = [ + g for g in traj_groups + if any( + abs(step.advantage) > 1e-8 + for traj in g.trajectories + for step in traj.steps + if step.advantage is not None + ) + ] + filtered_zero_adv = before_adv - len(traj_groups) + self._aggregator.record("buffer/filtered_zero_adv", filtered_zero_adv) + + if not traj_groups: + self._coordinator.on_group_filtered() + return True + + # 6. Set weight version and queue + for g in traj_groups: + g.weight_version = weight_version + + batch = TaskBatch(groups=traj_groups, episodes=episodes) + if self._tg_offload_dir: + await self._queue.put(await self._offload_task_batch(batch)) + else: + await self._queue.put(batch) + + return True + + async def get(self) -> TaskBatch | None: + """Get next task batch. Returns None when generation is done and buffer is drained.""" + item = await self._queue.get() + if item is None: + return None + return await self._load_task_batch(item) + + def mark_generation_complete(self) -> None: + """Signal that generation is finished. Flushes incomplete groups and enqueues a sentinel.""" + for task_id in list(self._pending.keys()): + items = self._pending.pop(task_id, []) + for item in items: + if isinstance(item, str): + try: + os.remove(item) + except OSError: + pass + self._coordinator.on_group_filtered() + self._queue.put_nowait(None) + + def stats(self) -> dict: + return { + "async/buffer_qsize": self._queue.qsize(), + "async/buffer_pending": len(self._pending), + } + + def _record_episode_metrics(self, episodes: list[Episode]) -> None: + """Record episode-level metrics to aggregator (all episodes, including filtered).""" + for ep in episodes: + reason = ep.termination_reason or TerminationReason.UNKNOWN + for r in TerminationReason: + self._aggregator.record( + f"episode/termination_reason/{r.value}", + 1.0 if reason == r else 0.0, + ) + for k, v in ep.metrics.items(): + try: + self._aggregator.record(f"episode/{k}", float(v)) + except (TypeError, ValueError): + continue + + # Sequence lengths and turn counts from trajectories + for traj in ep.trajectories: + n_steps = len(traj.steps) + prompt_tokens = sum(len(s.prompt_ids) for s in traj.steps) + response_tokens = sum(len(s.response_ids) for s in traj.steps) + self._aggregator.record("episode/num_turns", n_steps) + self._aggregator.record("episode/prompt_tokens", prompt_tokens) + self._aggregator.record("episode/response_tokens", response_tokens) + + @staticmethod + def _min_weight_version(episodes: list[Episode]) -> int: + min_v = float('inf') + for ep in episodes: + for traj in ep.trajectories: + for step in traj.steps: + if step.weight_version is not None: + min_v = min(min_v, step.weight_version) + return int(min_v) if min_v != float('inf') else 0 diff --git a/rllm/experimental/common/config.py b/rllm/experimental/common/config.py index 4f57691b7..4537ce2a8 100644 --- a/rllm/experimental/common/config.py +++ b/rllm/experimental/common/config.py @@ -12,8 +12,8 @@ class AsyncTrainingConfig: """Controls the async training behavior spectrum. - When `enabled` is False, the trainer uses the current synchronous pipeline. - When `enabled` is True, the trainer runs concurrent generation + training + When `enable` is False, the trainer uses the current synchronous pipeline. + When `enable` is True, the trainer runs concurrent generation + training with group-level streaming and dispatch-time throttle. Behavior spectrum: @@ -23,17 +23,23 @@ class AsyncTrainingConfig: - staleness_threshold>0, partial_rollout=True: Async with partial rollout """ - enabled: bool = False + enable: bool = False mini_batch_size: int = 1 # episode groups per optimizer step - streaming_chunks: int = 1 # forward-backward passes per optimizer step (must divide mini_batch_size) + fwd_bwd_group_size: int | None = None # task batches per forward-backward pass (default: mini_batch_size) staleness_threshold: float = 0.0 # 0.0 = on-policy. Controls dispatch throttle quota. trigger_parameter_sync_step: int = 1 # optimizer steps between weight sync + version bump partial_rollout: bool = True # enable turn-level gating during weight sync + episode_offload_dir: str | None = None # NVMe offload dir for pending episodes (None = disabled) + trajectory_group_offload_dir: str | None = None # NVMe offload dir for queued task batches (None = disabled) def __post_init__(self): - if self.enabled: - assert self.streaming_chunks >= 1 - assert self.mini_batch_size % self.streaming_chunks == 0, f"mini_batch_size ({self.mini_batch_size}) must be divisible by streaming_chunks ({self.streaming_chunks})" + if self.fwd_bwd_group_size is None: + self.fwd_bwd_group_size = self.mini_batch_size + if self.enable: + assert self.fwd_bwd_group_size >= 1 + assert self.mini_batch_size % self.fwd_bwd_group_size == 0, ( + f"mini_batch_size ({self.mini_batch_size}) must be divisible by fwd_bwd_group_size ({self.fwd_bwd_group_size})" + ) @dataclass @@ -125,7 +131,7 @@ class RolloutCorrectionConfig: Backend-agnostic — each backend interprets these according to its infrastructure. Attributes: - mode: None = disabled (string loss names, current behavior). + tis_mode: None = disabled (string loss names, current behavior). "token" or "sequence" = enable custom callable loss with TIS at that level. bypass_mode: When True, use rollout (inference) logprobs as π_old — no proximal forward pass. When False, compute π_old via policy.forward() @@ -133,7 +139,7 @@ class RolloutCorrectionConfig: tis_cap: Upper clamp on the TIS importance weight. """ - mode: str | None = None + tis_mode: str | None = None bypass_mode: bool = True tis_cap: float = 5.0 @@ -179,6 +185,7 @@ class AlgorithmConfig: kl_beta: float = 0.0 eps_clip: float = 0.2 eps_clip_high: float | None = None + loss_agg_mode: Literal["token_mean", "seq_mean_token_sum", "seq_mean_token_mean", None] = None rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig) router_replay: bool = False @@ -193,7 +200,7 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": """ rc_section = config.rllm.algorithm.get("rollout_correction", {}) rollout_correction = RolloutCorrectionConfig( - mode=rc_section.get("mode", None), + tis_mode=rc_section.get("tis_mode", None), bypass_mode=rc_section.get("bypass_mode", True), tis_cap=rc_section.get("tis_cap", 5.0), ) @@ -209,6 +216,7 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": kl_beta=config.rllm.algorithm.get("kl_beta", 0.0), eps_clip=config.rllm.algorithm.get("eps_clip", 0.2), eps_clip_high=config.rllm.algorithm.get("eps_clip_high", None), + loss_agg_mode=config.rllm.algorithm.get("loss_agg_mode", None), rollout_correction=rollout_correction, router_replay=config.rllm.algorithm.get("router_replay", False), ) diff --git a/rllm/experimental/config/rllm/base.yaml b/rllm/experimental/config/rllm/base.yaml index f6f94e688..f7f6f58e8 100644 --- a/rllm/experimental/config/rllm/base.yaml +++ b/rllm/experimental/config/rllm/base.yaml @@ -56,21 +56,15 @@ algorithm: # When true, always use pre-computed step.advantage from the workflow (e.g. distillation) # and skip advantage computation (GRPO/REINFORCE). Missing advantages default to 0. use_precomputed_advantage: false - # for tinker backend only (avaiable options: importance_sampling, ppo, cispo, dro, cross_entropy) - loss_fn: null - - # Custom loss / rollout correction (used by Fireworks backend with cookbook losses) - kl_beta: 0.0 # KL penalty coefficient; >0 enables reference forward pass + loss_fn: null # [null, importance_sampling, ppo, cispo, dro, cross_entropy] + loss_agg_mode: null # [null, token-mean, seq-mean-token-sum, seq-mean-token-mean] + kl_beta: 0.0 # KL penalty coefficient eps_clip: 0.2 # PPO clip epsilon eps_clip_high: null # Asymmetric upper clip bound (null = symmetric) - - # Router Replay (R3): replay MoE expert routing from inference during training - router_replay: false - - # Rollout correction: corrects FP8 (inference) vs FP32 (training) drift + router_replay: false # Router Replay (R3): replay MoE expert routing from inference during training rollout_correction: - mode: null # null = disabled, "token" or "sequence" = enable custom callable loss bypass_mode: true # true = use rollout logprobs as pi_old (2-policy), false = proximal forward (3-policy) + tis_mode: null # null = disabled, "token" or "sequence" = TIS importance sampling level tis_cap: 5.0 # Upper clamp on TIS importance weight # Stepwise advantage @@ -124,12 +118,14 @@ sdk: # Async Training Configuration async_training: - enabled: false + enable: false mini_batch_size: 1 - streaming_chunks: 1 + fwd_bwd_group_size: null # task batches per forward-backward pass (default: mini_batch_size) staleness_threshold: 0.0 trigger_parameter_sync_step: 1 partial_rollout: true + episode_offload_dir: null # NVMe offload dir for pending episodes (null = disabled) + trajectory_group_offload_dir: null # NVMe offload dir for queued task batches (null = disabled) # Episode Logging Configuration episode_logging: diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index 5084f5144..950d2974c 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -221,51 +221,6 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No return ordered_results - async def execute_tasks_streaming( - self, - tasks: list[dict], - task_ids: list[str] | None = None, - queue: asyncio.Queue | None = None, - is_validation: bool = False, - **kwargs, - ) -> None: - """Run async workflow execution, pushing each completed episode to queue immediately. - - Concurrency is bounded by the existing workflow_queue (acts as semaphore). - No episode logging, no tqdm — designed for the fully-async training path. - - Each completed episode is pushed as a tuple: (task_id, rollout_idx, result_idx, episode). - - Args: - tasks: List of task dictionaries to process. - task_ids: Optional list of task identifiers. If None, UUIDs are generated. - queue: asyncio.Queue to push completed episodes into. - is_validation: Whether the generation is for validation. - **kwargs: Additional arguments passed to individual task processing. - """ - assert queue is not None, "queue must be provided for streaming execution" - if self.workflow_queue is None: - await self.initialize_pool() - - self.rollout_engine.is_validation = is_validation - - if task_ids is None: - task_ids = [str(uuid.uuid4()) for _ in tasks] - - task_id_counter = defaultdict(int) - - async def _process_and_push(task, task_id, rollout_idx, result_idx): - result = await self.process_task_with_retry(task, task_id, rollout_idx, result_idx, **kwargs) - await queue.put(result) - - tasks_to_run = [] - for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): - rollout_idx = task_id_counter[task_id] - tasks_to_run.append(_process_and_push(task, task_id, rollout_idx, idx)) - task_id_counter[task_id] += 1 - - await asyncio.gather(*tasks_to_run) - # TODO(listar2000): eventually the agent_workflow_engine should be backend agnostic. async def execute_tasks_verl(self, batch: DataProto, is_validation: bool = False, **kwargs) -> list[Episode]: """Execute tasks from a Verl DataProto batch and return results. diff --git a/rllm/experimental/episode_buffer.py b/rllm/experimental/episode_buffer.py deleted file mode 100644 index df89e0719..000000000 --- a/rllm/experimental/episode_buffer.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Episode buffer protocol and asyncio implementation for async training. - -The buffer is a dumb pipe — no staleness filtering. Staleness is controlled -at dispatch time by SyncCoordinator's throttle quota. -""" - -from __future__ import annotations - -import asyncio -from abc import ABC, abstractmethod -from dataclasses import dataclass - -from rllm.agents.agent import Episode - - -@dataclass -class BufferedEpisodeGroup: - """All n episodes for one prompt, collected before buffering.""" - - episodes: list[Episode] - weight_version: int # earliest weight_version across all steps in all episodes - task_id: str - - -class EpisodeGroupAccumulator: - """Per-task collector that groups episodes by task_id before pushing to buffer. - - Lives in the generation loop, NOT inside the buffer (buffer stays a dumb pipe). - Optionally filters out groups with no gradient signal (all correct or all incorrect). - """ - - def __init__( - self, - group_size: int, - buffer: EpisodeBufferProtocol, - filter_uniform_groups: bool = False, - on_group_filtered: callable | None = None, - ): - self._group_size = group_size - self._buffer = buffer - self._filter_uniform_groups = filter_uniform_groups - self._on_group_filtered = on_group_filtered - self._pending: dict[str, list[Episode]] = {} - self.total_filtered: int = 0 - - async def add_episode(self, task_id: str, episode: Episode) -> bool: - """Add episode. Returns True if group completed (pushed or filtered).""" - self._pending.setdefault(task_id, []).append(episode) - if len(self._pending[task_id]) == self._group_size: - episodes = self._pending.pop(task_id) - - if self._filter_uniform_groups and len({ep.is_correct for ep in episodes}) == 1: - self.total_filtered += 1 - if self._on_group_filtered: - self._on_group_filtered() - return True - - earliest = self._compute_earliest_version(episodes) - await self._buffer.put(BufferedEpisodeGroup(episodes=episodes, weight_version=earliest, task_id=task_id)) - return True - return False - - @staticmethod - def _compute_earliest_version(episodes: list[Episode]) -> int: - min_v = float("inf") - for ep in episodes: - for traj in ep.trajectories: - for step in traj.steps: - if step.weight_version is not None: - min_v = min(min_v, step.weight_version) - return int(min_v) if min_v != float("inf") else 0 - - -class EpisodeBufferProtocol(ABC): - """Abstract base class for episode buffers. - - Different backends can provide different implementations: - - AsyncioEpisodeBuffer: Single-threaded asyncio.Queue for Tinker - - RayEpisodeBuffer (future): Ray actor for multi-process Verl - """ - - @abstractmethod - async def put(self, item: BufferedEpisodeGroup) -> None: - """Add an episode group to the buffer.""" - - @abstractmethod - async def get(self) -> BufferedEpisodeGroup | None: - """Get next episode group. Returns None when generation is done and buffer is empty.""" - - @abstractmethod - def mark_generation_complete(self) -> None: - """Signal that generation is finished.""" - - @abstractmethod - def qsize(self) -> int: - """Current number of episode groups in the buffer.""" - - @abstractmethod - def stats(self) -> dict: - """Buffer statistics for metrics.""" - - -class AsyncioEpisodeBuffer(EpisodeBufferProtocol): - """Unbounded asyncio.Queue-based buffer for Tinker backend. - - No staleness filtering — throttle controls growth externally via SyncCoordinator. - Tinker's compute happens on remote servers, so the Python process only - orchestrates — no threading needed. - """ - - def __init__(self): - self._queue: asyncio.Queue[BufferedEpisodeGroup | None] = asyncio.Queue() # unbounded - self._generation_complete = False - self._total_produced = 0 - self._total_consumed = 0 - - async def put(self, item: BufferedEpisodeGroup) -> None: - await self._queue.put(item) - self._total_produced += 1 - - async def get(self) -> BufferedEpisodeGroup | None: - while True: - if self._generation_complete and self._queue.empty(): - return None - try: - item = await asyncio.wait_for(self._queue.get(), timeout=1.0) - except asyncio.TimeoutError: - if self._generation_complete and self._queue.empty(): - return None - continue - if item is None: # sentinel - return None - self._total_consumed += 1 - return item - - def mark_generation_complete(self) -> None: - self._generation_complete = True - - def qsize(self) -> int: - return self._queue.qsize() - - def stats(self) -> dict: - return { - "async/episode_buffer_size": self._queue.qsize(), - "async/total_produced": self._total_produced, - "async/total_consumed": self._total_consumed, - } diff --git a/rllm/experimental/metrics.py b/rllm/experimental/metrics.py new file mode 100644 index 000000000..2ee4734d6 --- /dev/null +++ b/rllm/experimental/metrics.py @@ -0,0 +1,122 @@ +"""MetricsAggregator for async training. + +Accumulates metric observations from multiple sources (buffer, training loop, +coordinator) and reduces them with per-key aggregation rules at flush time. +""" + +from __future__ import annotations + +from collections import defaultdict + +import numpy as np + + +# Keys that should be summed rather than averaged. +_SUM_KEYS: set[str] = { + "grouping/num_trajs_before_filter", + "grouping/num_trajs_after_filter", + "grouping/num_groups", + "buffer/filtered_min_trajs", + "buffer/filtered_zero_adv", +} + +# Prefixes where "last value" is the correct reduction. +_LAST_PREFIXES: tuple[str, ...] = ( + "time/", + "progress/", + "optim/", + "async/", +) + +# Prefixes where "mean" is the correct reduction. +_MEAN_PREFIXES: tuple[str, ...] = ( + "episode/", +) + + +def _infer_rule(key: str) -> str: + """Infer aggregation rule from metric key name. + + Resolution order: + 1. Explicit sum keys + 2. Prefix-based rules (last or mean) + 3. Keyword-based rules (/max, /min, /mean, /avg, /std, /fraction) + 4. Default: mean + """ + if key in _SUM_KEYS: + return "sum" + + for prefix in _LAST_PREFIXES: + if key.startswith(prefix): + return "last" + + for prefix in _MEAN_PREFIXES: + if key.startswith(prefix): + return "mean" + + # Keyword inference from the key name + if "/max" in key: + return "max" + if "/min" in key: + return "min" + if "/mean" in key or "/avg" in key: + return "mean" + if "/std" in key or "/fraction" in key: + return "mean" + + return "mean" + + +def _reduce(rule: str, values: list[float]) -> float: + if rule == "mean": + return sum(values) / len(values) + if rule == "sum": + return sum(values) + if rule == "max": + return max(values) + if rule == "min": + return min(values) + if rule == "last": + return values[-1] + return sum(values) / len(values) + + +class MetricsAggregator: + """Accumulates metric observations and flushes as an aggregated plain dict. + + Usage:: + + agg = MetricsAggregator() + + # record from various sources + agg.record("episode/queue_wait", 0.3) + agg.record("episode/queue_wait", 0.5) + agg.record_dict(transform_metrics) + + # at log time + plain_dict = agg.flush() # reduces, clears, returns dict + """ + + def __init__(self) -> None: + self._values: dict[str, list[float]] = defaultdict(list) + + def record(self, key: str, value: float) -> None: + """Record a single metric observation.""" + self._values[key].append(float(value)) + + def record_dict(self, metrics: dict) -> None: + """Record all numeric values from a dict, coercing types.""" + for k, v in metrics.items(): + if isinstance(v, (int, float)): + self._values[k].append(float(v)) + elif isinstance(v, np.number): + self._values[k].append(float(v)) + + def flush(self) -> dict[str, float]: + """Reduce all accumulated values and return a plain dict. Clears state.""" + result = {} + for key, values in self._values.items(): + if values: + result[key] = _reduce(_infer_rule(key), values) + self._values.clear() + return result diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index 9a916577a..b32741fa6 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -106,30 +106,6 @@ async def generate_episodes( """ raise NotImplementedError("Subclasses must implement this method.") - async def generate_episodes_streaming( - self, - batch: TBatch, - agent_workflow_engine: UnifiedWorkflowEngine, - episode_queue, - is_validation: bool = False, - **kwargs, - ) -> None: - """Generate episodes and push each to episode_queue as it completes. - - Default: falls back to generate_episodes() and pushes all to queue. - Backends can override for true streaming. - - Args: - batch: The input batch. - agent_workflow_engine: The workflow engine to use. - episode_queue: asyncio.Queue to push (task_id, rollout_idx, result_idx, episode) tuples. - is_validation: Whether the generation is for validation. - **kwargs: Additional arguments. - """ - episodes = await self.generate_episodes(batch, agent_workflow_engine, is_validation, **kwargs) - for i, ep in enumerate(episodes): - await episode_queue.put((ep.task_id, getattr(ep, "rollout_idx", 0), i, ep)) - @abstractmethod def transform_to_backend_batch( self, diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py index cbef5fd88..10ac61ace 100644 --- a/rllm/experimental/sync_coordinator.py +++ b/rllm/experimental/sync_coordinator.py @@ -8,8 +8,8 @@ @dataclass class SyncCoordinatorConfig: - mini_batch_size: int # episode groups per optimizer step - group_size: int # episodes per group (rollout.n) + mini_batch_size: int # episode groups per optimizer step + group_size: int # episodes per group (rollout.n) staleness_threshold: float trigger_parameter_sync_step: int @@ -26,7 +26,7 @@ def __init__(self, config: SyncCoordinatorConfig): self.config = config self._policy_version: int = 0 - self._outstanding: int = 0 # groups dispatched but not yet consumed by training + self._outstanding: int = 0 # groups dispatched but not yet consumed by training self._steps_since_sync: int = 0 self._total_syncs: int = 0 self._total_groups_filtered: int = 0 @@ -39,8 +39,6 @@ def __init__(self, config: SyncCoordinatorConfig): self._generation_paused: asyncio.Event = asyncio.Event() self._generation_paused.set() - self.generation_done: bool = False - @property def policy_version(self) -> int: return self._policy_version @@ -84,7 +82,7 @@ def on_sync_complete(self) -> None: self._steps_since_sync = 0 self._total_syncs += 1 - # --- Generation pause (for validation / non-partial weight sync) --- + # --- Generation pause (for validation / weight sync if partial_rollout is False) --- def pause_generation(self) -> None: self._generation_paused.clear() diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 4f24e2be4..e9ae35ddd 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -36,7 +36,8 @@ ) from rllm.experimental.common.visualization import visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine -from rllm.experimental.episode_buffer import AsyncioEpisodeBuffer, BufferedEpisodeGroup, EpisodeGroupAccumulator +from rllm.experimental.buffer import TrajectoryGroupBuffer +from rllm.experimental.metrics import MetricsAggregator from rllm.experimental.protocol import BackendProtocol from rllm.experimental.sync_coordinator import SyncCoordinator, SyncCoordinatorConfig from rllm.utils import EpisodeLogger, Tracking, extract_source_metadata @@ -134,9 +135,9 @@ def __init__( # Async training config async_cfg = self.rllm_config.get("async_training", {}) self.async_config = AsyncTrainingConfig( - enabled=async_cfg.get("enabled", False), + enable=async_cfg.get("enable", False), mini_batch_size=async_cfg.get("mini_batch_size", 1), - streaming_chunks=async_cfg.get("streaming_chunks", 1), + fwd_bwd_group_size=async_cfg.get("fwd_bwd_group_size", 1), staleness_threshold=async_cfg.get("staleness_threshold", 0.0), trigger_parameter_sync_step=async_cfg.get("trigger_parameter_sync_step", 1), partial_rollout=async_cfg.get("partial_rollout", True), @@ -264,7 +265,7 @@ async def fit_async(self) -> None: async def _fit_async(self, trainer_state: TrainerState) -> None: """Dispatch to sync or concurrent training based on config.""" # TODO(listar2000): after some benchmarking, maybe we just keep the fully-async and treat on-policy as a special case. - if self.async_config.enabled: + if self.async_config.enable: await self._fit_fully_async(trainer_state) else: await self._fit_on_policy(trainer_state) @@ -387,7 +388,12 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N async def _fit_fully_async(self, trainer_state: TrainerState) -> None: """Fully-async generation + training with group-level streaming.""" - assert self.config.data.train_batch_size == 1, f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" + assert self.config.data.train_batch_size == 1, ( + f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" + ) + assert not self.agent_workflow_engine.raise_on_error, ( + "Async training requires raise_on_error=False so that process_task_with_retry always returns an episode" + ) coord_config = SyncCoordinatorConfig( mini_batch_size=self.async_config.mini_batch_size, group_size=self.rllm_config.rollout.n, @@ -395,7 +401,18 @@ async def _fit_fully_async(self, trainer_state: TrainerState) -> None: trigger_parameter_sync_step=self.async_config.trigger_parameter_sync_step, ) coordinator = SyncCoordinator(coord_config) - buffer = AsyncioEpisodeBuffer() + aggregator = MetricsAggregator() + buffer = TrajectoryGroupBuffer( + group_size=self.rllm_config.rollout.n, + coordinator=coordinator, + aggregator=aggregator, + algorithm_config=self.algorithm_config, + transform_config=self.transform_config, + cf_config=self.cf_config, + rs_config=self.rs_config, + episode_offload_dir=self.async_config.episode_offload_dir, + trajectory_group_offload_dir=self.async_config.trajectory_group_offload_dir, + ) # Compute total_steps for LR scheduling train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) @@ -405,20 +422,20 @@ async def _fit_fully_async(self, trainer_state: TrainerState) -> None: else: trainer_state.total_steps = len(train_dataloader) * self.rllm_config.trainer.total_epochs - await asyncio.gather( - self._generation_loop(trainer_state, buffer, coordinator), - self._training_loop(trainer_state, buffer, coordinator), - ) - - async def _generation_loop(self, trainer_state: TrainerState, buffer: AsyncioEpisodeBuffer, coordinator: SyncCoordinator) -> None: - """Generate episodes and stream to buffer. Continuous fire-and-forget per prompt.""" + gen_task = asyncio.create_task(self._generation_loop(trainer_state, buffer, coordinator)) + await self._training_loop(trainer_state, buffer, coordinator, aggregator) + if not gen_task.done(): + gen_task.cancel() + try: + await gen_task + except asyncio.CancelledError: + pass + + async def _generation_loop( + self, trainer_state: TrainerState, buffer: TrajectoryGroupBuffer, coordinator: SyncCoordinator, + ) -> None: + """Generate episodes and stream to TrajectoryGroupBuffer.""" group_size = self.rllm_config.rollout.n - accumulator = EpisodeGroupAccumulator( - group_size=group_size, - buffer=buffer, - filter_uniform_groups=self.rs_config.filter_uniform_groups, - on_group_filtered=coordinator.on_group_filtered, - ) try: for epoch in range(self.rllm_config.trainer.total_epochs): @@ -426,119 +443,96 @@ async def _generation_loop(self, trainer_state: TrainerState, buffer: AsyncioEpi self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) for batch in train_dataloader: - # async training uses train_batch_size=1 task = batch[0] - # Block during validation / non-partial sync await coordinator.wait_for_generation_allowed() - - # Dispatch-time throttle: block if quota exhausted if not coordinator.has_quota(): await coordinator.wait_for_throttle() - coordinator.on_group_dispatched() - # Generate a unique task_id for this prompt task_id = str(uuid.uuid4()) - - # Fire-and-forget n rollout tasks for this prompt for rollout_idx in range(group_size): - async def _run_rollout(t=task, tid=task_id, ridx=rollout_idx): - try: - _, _, _, episode = await self.agent_workflow_engine.process_task_with_retry(task=t, task_id=tid, rollout_idx=ridx, result_idx=0) - await accumulator.add_episode(tid, episode) - except Exception: - # Group can never complete — free the throttle slot to prevent deadlock - coordinator.on_group_consumed() - raise - + _, _, _, episode = await self.agent_workflow_engine.process_task_with_retry( + task=t, task_id=tid, rollout_idx=ridx, result_idx=0 + ) + await buffer.add_episode(tid, episode) asyncio.create_task(_run_rollout()) - # Wait for all in-flight rollouts to finish before marking generation complete + await self._wait_for_all_workflows_idle() finally: - coordinator.generation_done = True buffer.mark_generation_complete() - async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpisodeBuffer, coordinator: SyncCoordinator) -> None: - """Collect episode groups from buffer, train with streaming grad accumulation. Runs concurrently with generation.""" + async def _training_loop( + self, + trainer_state: TrainerState, + buffer: TrajectoryGroupBuffer, + coordinator: SyncCoordinator, + aggregator: MetricsAggregator, + ) -> None: + """Consume task batches from buffer, run forward-backward + optimizer step.""" mini_batch_size = self.async_config.mini_batch_size - streaming_chunks = self.async_config.streaming_chunks - groups_per_chunk = mini_batch_size // streaming_chunks + fwd_bwd_group_size = self.async_config.fwd_bwd_group_size + num_fwd_bwd_passes = mini_batch_size // fwd_bwd_group_size use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 rollout_engine = self.agent_workflow_engine.rollout_engine while True: trainer_state.reset_batch() step_start = time.perf_counter() - all_collected: list[BufferedEpisodeGroup] = [] + weight_versions = [] + all_trajectory_groups: list[TrajectoryGroup] = [] all_episodes: list[Episode] = [] + groups_consumed = 0 buffer_wait_time = 0.0 + done = False - # 1. Streaming gradient accumulation across chunks - for chunk_idx in range(streaming_chunks): - # Pull groups_per_chunk groups from buffer - chunk_groups: list[BufferedEpisodeGroup] = [] - while len(chunk_groups) < groups_per_chunk: - t0 = time.perf_counter() - item = await buffer.get() - buffer_wait_time += time.perf_counter() - t0 - if item is None: - break # generation done + buffer empty - chunk_groups.append(item) - - if not chunk_groups: - break + # 1. Pull mini_batch_size task batches total, split into + # num_fwd_bwd_passes forward-backward passes of fwd_bwd_group_size each. + for _ in range(num_fwd_bwd_passes): + chunk_groups: list[TrajectoryGroup] = [] + + for _ in range(fwd_bwd_group_size): + t_wait = time.perf_counter() + task_batch = await buffer.get() + buffer_wait_time += time.perf_counter() - t_wait + if task_batch is None: + done = True + break - for _ in chunk_groups: coordinator.on_group_consumed() - all_collected.extend(chunk_groups) - - # Flatten episodes from groups - episodes = [] - for group in chunk_groups: - episodes.extend(group.episodes) - all_episodes.extend(episodes) - - # Transform → rejection sampling → backend pipeline - trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups( - episodes, - self.transform_config, - self.cf_config, - traj_grouping_hook=self.traj_grouping_hook, - ) - trainer_state.trajectory_groups = trajectory_groups - trainer_state.episodes = episodes - trainer_state.metrics.update(transform_metrics) - - filtered_groups, filtered_episodes, rs_metrics = apply_rejection_sampling_and_filtering( - episodes, - trajectory_groups, - self.rs_config, - RejectionSamplingState(), - ) - trainer_state.metrics.update(rs_metrics) - trainer_state.trajectory_groups = filtered_groups - trainer_state.episodes = filtered_episodes - if not trainer_state.has_trajectory_groups: - continue + groups_consumed += 1 - await self.backend.on_batch_start(trainer_state) - trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) - await self.backend.process_backend_batch(trainer_state) - await self.backend.compute_advantages(trainer_state, self.algorithm_config) + for group in task_batch.groups: + weight_versions.append(group.weight_version) + chunk_groups.extend(task_batch.groups) + all_trajectory_groups.extend(task_batch.groups) + all_episodes.extend(task_batch.episodes) - if not all_collected: - if coordinator.generation_done and buffer.qsize() == 0: + if not chunk_groups or done: break - continue - # 2. Single optimizer step - trainer_state.episodes = all_episodes + # Forward-backward on this chunk + trainer_state.trajectory_groups = chunk_groups + + if trainer_state.has_trajectory_groups: + await self.backend.on_batch_start(trainer_state) + trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) + await self.backend.process_backend_batch(trainer_state) + + # Drain per-chunk backend metrics into aggregator + aggregator.record_dict(trainer_state.metrics) + trainer_state.metrics = {} + + # Only run optimizer step on a full batch + if groups_consumed < mini_batch_size: + break + + # 2. Optimizer step await self.backend.update_policy(trainer_state) - # 3. Training step done — check sync + # 3. Weight sync coordinator.on_training_step_complete() sync_time = 0.0 if coordinator.should_sync(): @@ -546,37 +540,20 @@ async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpiso await self._perform_weight_sync(trainer_state, coordinator, rollout_engine) sync_time = time.perf_counter() - t0 - # 4. Metrics, logging, visualization - workflow_metrics, termination_counts = self._collect_workflow_metrics_from_episodes(all_episodes) - - staleness_values = [coordinator.policy_version - g.weight_version for g in all_collected] - trainer_state.metrics["async/staleness_mean"] = np.mean(staleness_values) - trainer_state.metrics["async/staleness_min"] = np.min(staleness_values) - trainer_state.metrics["async/staleness_max"] = np.max(staleness_values) - trainer_state.metrics["async/groups_consumed"] = len(all_collected) - - # Timing - trainer_state.metrics["time/step"] = time.perf_counter() - step_start - trainer_state.metrics["time/buffer_wait"] = buffer_wait_time + # 4. Record training-loop metrics to aggregator + staleness_values = [coordinator.policy_version - v for v in weight_versions] + aggregator.record("async/staleness_mean", float(np.mean(staleness_values))) + aggregator.record("async/staleness_min", float(np.min(staleness_values))) + aggregator.record("async/staleness_max", float(np.max(staleness_values))) + aggregator.record("async/groups_consumed", groups_consumed) + aggregator.record("time/step", time.perf_counter() - step_start) + aggregator.record("time/buffer_wait", buffer_wait_time) if sync_time > 0: - trainer_state.metrics["time/weight_sync"] = sync_time - - # Weight version delta within trajectories (meaningful in partial_rollout mode) - traj_deltas = [] - for ep in all_episodes: - for traj in ep.trajectories: - versions = [s.weight_version for s in traj.steps if s.weight_version is not None] - if len(versions) >= 2: - traj_deltas.append(max(versions) - min(versions)) - if traj_deltas: - trainer_state.metrics["async/traj_weight_delta_mean"] = np.mean(traj_deltas) - trainer_state.metrics["async/traj_weight_delta_min"] = np.min(traj_deltas) - trainer_state.metrics["async/traj_weight_delta_max"] = np.max(traj_deltas) - - buffer_stats = buffer.stats() - trainer_state.metrics["async/gen_train_ratio"] = buffer_stats["async/total_produced"] / max(trainer_state.global_step, 1) - trainer_state.metrics.update(buffer_stats) - trainer_state.metrics.update(coordinator.stats()) + aggregator.record("time/weight_sync", sync_time) + + # Set all trajectory groups and stripped episodes for visualization/logging + trainer_state.trajectory_groups = all_trajectory_groups + trainer_state.episodes = all_episodes if self.tokenizer is not None and trainer_state.has_trajectory_groups: visualize_trajectory_last_steps( @@ -586,19 +563,18 @@ async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpiso show_workflow_metadata=True, ) - for key, value in workflow_metrics.items(): - trainer_state.metrics[f"batch/{key}"] = np.mean(value) - - total_counts = max(sum(termination_counts.values()), 1) - for r in TerminationReason: - trainer_state.metrics[f"batch/termination_reason/{r.value}"] = termination_counts[r.value] / total_counts - + # 5. on_batch_end writes backend metrics (progress, optim, timing) to trainer_state.metrics await self.backend.on_batch_end(trainer_state) + # 6. Flush aggregator and merge snapshots into trainer_state.metrics for logging + trainer_state.metrics.update(aggregator.flush()) + trainer_state.metrics.update(buffer.stats()) + trainer_state.metrics.update(coordinator.stats()) + self.logger.log( data=trainer_state.metrics, step=trainer_state.global_step, - episodes=all_episodes, + episodes=trainer_state.episodes, trajectory_groups=trainer_state.trajectory_groups, ) @@ -608,7 +584,6 @@ async def _training_loop(self, trainer_state: TrainerState, buffer: AsyncioEpiso trainer_state.global_step += 1 - # Check total_batches limit if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: break diff --git a/rllm/parser/chat_template_parser.py b/rllm/parser/chat_template_parser.py index 8411391e9..5f13cb970 100644 --- a/rllm/parser/chat_template_parser.py +++ b/rllm/parser/chat_template_parser.py @@ -694,7 +694,7 @@ def parse_prompt_from_messages(self, messages, add_generation_prompt=False, is_f raise NotImplementedError(f"Unsupported message role: {message['role']}") conv = Conversation.from_messages(harmony_messages) - accumulate_thinking = kwargs.get("accumulate_thinking", False) + accumulate_thinking = kwargs.get("accumulate_reasoning", kwargs.get("accumulate_thinking", False)) config = RenderConversationConfig(auto_drop_analysis=not accumulate_thinking) prompt_ids: list[int] = self.enc.render_conversation(conv, config) diff --git a/rllm/tools/tool_base.py b/rllm/tools/tool_base.py index 91b8a2ec8..446c599d2 100644 --- a/rllm/tools/tool_base.py +++ b/rllm/tools/tool_base.py @@ -1,4 +1,5 @@ import inspect +import json from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -10,10 +11,22 @@ class ToolCall: name: str arguments: dict[str, Any] + id: str | None = None + metadata: dict | None = None def to_dict(self): return {"name": self.name, "arguments": self.arguments} + def to_openai_format(self): + return { + "id": self.id or "unknown", + "type": "function", + "function": { + "name": self.name, + "arguments": json.dumps(self.arguments), + }, + } + @dataclass class ToolOutput: @@ -45,6 +58,24 @@ def to_string(self) -> str: """ return str(self) + def to_dict(self) -> dict: + """Convert the tool output to a dictionary for JSON serialization.""" + return { + "name": self.name, + "output": self.output, + "error": self.error, + "metadata": self.metadata, + } + + def to_openai_format(self) -> dict: + """Convert the tool output to OpenAI tool message format.""" + tool_call_id = (self.metadata.get("call_id") if self.metadata else None) or "unknown" + return { + "role": "tool", + "content": self.to_string(), + "tool_call_id": tool_call_id, + } + class Tool: """ diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index e0598b9dc..32f236b59 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -166,6 +166,8 @@ def get_dataloader(self, dataset: Dataset | None, trainer_state: TrainerState) - shuffle = True else: batch_size = self.full_config.data.get("val_batch_size", self.full_config.data.train_batch_size) + if batch_size == -1: + batch_size = len(dataset) shuffle = False return torch.utils.data.DataLoader( @@ -227,29 +229,6 @@ async def generate_episodes( return episodes - async def generate_episodes_streaming( - self, - batch: Any, - agent_workflow_engine: UnifiedWorkflowEngine, - episode_queue, - is_validation: bool = False, - **kwargs, - ) -> None: - """Generate episodes using streaming — push each to queue as it completes. - - Same setup as generate_episodes but uses execute_tasks_streaming. - """ - assert self.rollout_engine is not None, "rollout_engine is not initialized" - assert self.sampling_client is not None, "sampling_client is not initialized" - - self.rollout_engine.set_sampling_client(self.sampling_client) - - group_size = self.full_config.rllm.rollout.n_val if is_validation else self.full_config.rllm.rollout.n - interleaved_batch = _build_interleave_batch(batch, group_size) - task_ids = [item["uid"] for item in interleaved_batch] - - await agent_workflow_engine.execute_tasks_streaming(interleaved_batch, task_ids, queue=episode_queue, is_validation=is_validation, **kwargs) - def transform_to_backend_batch( self, trainer_state: TrainerState, diff --git a/rllm/trainer/tinker/transform.py b/rllm/trainer/tinker/transform.py index b016535c3..036a62b27 100644 --- a/rllm/trainer/tinker/transform.py +++ b/rllm/trainer/tinker/transform.py @@ -36,7 +36,7 @@ def _flatten_token_input(token_input: TinkerTokenInput) -> TinkerTokenInput: return flattened -def trajectory_to_datums(traj: Trajectory) -> list[tinker.Datum]: +def trajectory_to_datums(traj: Trajectory, router_replay: bool = False) -> list[tinker.Datum]: """ Return one or more Datum objects corresponding to the trajectory. If the sequence grows by appending, i.e., each successive observation contains @@ -61,6 +61,7 @@ class SequenceAccumulator: sampled_logprobs: list[float] = [] advantages: list[float] = [] mask: list[float] = [] + routing_matrices: list[str] = [] @classmethod def clear(cls): @@ -68,6 +69,7 @@ def clear(cls): cls.sampled_logprobs = [] cls.advantages = [] cls.mask = [] + cls.routing_matrices = [] def make_datum_from_state(): all_tokens_T = _flat_token_input_to_model_input(SequenceAccumulator.full_sequence) @@ -77,6 +79,9 @@ def make_datum_from_state(): advantages_T = SequenceAccumulator.advantages[1:] mask_T = SequenceAccumulator.mask[1:] assert input_tokens_T.length == len(target_tokens_T) == len(sampled_logprobs_T) == len(advantages_T) == len(mask_T) + if router_replay and SequenceAccumulator.routing_matrices: + rm_shifted = SequenceAccumulator.routing_matrices[1:] # match rightshift + input_tokens_T = input_tokens_T.model_copy(update={"routing_matrices": rm_shifted}) return tinker.Datum( model_input=input_tokens_T, loss_fn_inputs={ @@ -118,6 +123,12 @@ def make_datum_from_state(): SequenceAccumulator.sampled_logprobs.extend([0.0] * delta_token_input_length + output_logprobs) SequenceAccumulator.advantages.extend([0] * delta_token_input_length + advantages) SequenceAccumulator.mask.extend([0.0] * delta_token_input_length + [1.0] * len(output_token_ids)) + if router_replay: + step_rm = step.routing_matrices or [] + SequenceAccumulator.routing_matrices.extend( + [""] * delta_token_input_length + + (list(step_rm) if step_rm else [""] * len(output_token_ids)) + ) if SequenceAccumulator.full_sequence: data.append(make_datum_from_state()) @@ -137,9 +148,17 @@ def transform_trajectory_groups_to_datums( If the `estimator_map` is used in the algorithm config, we return a dictionary of datums, keyed by the trajectory group role. Otherwise, we return a list of datums. """ - # step 1: compute the advantages for each group using the common functionality - # this fills the `advantage` attribute of all the steps in the trajectory groups - adv_metrics = collect_reward_and_advantage_from_trajectory_groups(trajectory_groups, algorithm_config) + # step 1: compute advantages (skip if already pre-computed by buffer) + has_advantages = any( + step.advantage is not None + for group in trajectory_groups + for traj in group.trajectories + for step in traj.steps + ) + if has_advantages: + adv_metrics = {} + else: + adv_metrics = collect_reward_and_advantage_from_trajectory_groups(trajectory_groups, algorithm_config) if algorithm_config.estimator_map: datums_dict = defaultdict(list) @@ -147,11 +166,26 @@ def transform_trajectory_groups_to_datums( datums = [] # step 2: iterate over all steps and build the Tinker Datum objects + datums_per_traj = [] + seq_lengths = [] for group in trajectory_groups: for trajectory in group.trajectories: + traj_datums = trajectory_to_datums(trajectory, router_replay=algorithm_config.router_replay) + datums_per_traj.append(len(traj_datums)) + for d in traj_datums: + seq_lengths.append(d.model_input.length) if algorithm_config.estimator_map: - datums_dict[group.group_role].extend(trajectory_to_datums(trajectory)) + datums_dict[group.group_role].extend(traj_datums) else: - datums.extend(trajectory_to_datums(trajectory)) + datums.extend(traj_datums) + + if datums_per_traj: + import numpy as _np + adv_metrics["train/datums_per_traj/mean"] = _np.mean(datums_per_traj) + adv_metrics["train/datums_per_traj/min"] = _np.min(datums_per_traj) + adv_metrics["train/datums_per_traj/max"] = _np.max(datums_per_traj) + adv_metrics["train/seq_length/mean"] = _np.mean(seq_lengths) + adv_metrics["train/seq_length/min"] = _np.min(seq_lengths) + adv_metrics["train/seq_length/max"] = _np.max(seq_lengths) return (datums if not algorithm_config.estimator_map else datums_dict), adv_metrics diff --git a/rllm/workflows/workflow.py b/rllm/workflows/workflow.py index f26ca3072..b09237ded 100644 --- a/rllm/workflows/workflow.py +++ b/rllm/workflows/workflow.py @@ -179,7 +179,7 @@ def assign_episode_correctness(self, episode: Episode) -> None: """ total_reward = 0 for trajectory in episode.trajectories: - total_reward += trajectory.reward + total_reward += trajectory.reward or 0 episode.is_correct = total_reward > 0 def collect_metrics(self, episode: Episode) -> None: @@ -192,7 +192,7 @@ def collect_metrics(self, episode: Episode) -> None: metrics = defaultdict(list) for traj in episode.trajectories: name = traj.name - metrics[name].append(traj.reward) + metrics[name].append(traj.reward or 0.0) episode.metrics = {f"{k}_acc": float(np.mean(v)) for k, v in metrics.items()} def postprocess_episode(self, episode: Episode, termination_reason: TerminationReason = None, error: dict = None) -> Episode: From f77f94a1a4e418c5443f84fc9fb082b136ad8f5c Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Mon, 30 Mar 2026 20:35:54 -0700 Subject: [PATCH 09/21] bump vllm, deepcopy msgs in Step's post_init --- pyproject.toml | 11 +++++------ rllm/agents/agent.py | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 590d5405c..4de38f6b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,16 +84,15 @@ dev = [ ] verl = [ - "torch", + "verl==0.7.1", + "vllm>=0.10.2,<=0.17.0", "transformers>=4.55.0,<5.0.0", "numpy", - "verl==0.7.1", - "ray", - "torch>=2.8.0", - "torchvision>=0.23.0", - "vllm>=0.10.2,<=0.12.0", + "torch", + "torchvision, "flash-attn>=2.8.1", "qwen-vl-utils", + "ray", ] sdk = [ diff --git a/rllm/agents/agent.py b/rllm/agents/agent.py index 3530fe5ad..b3239bea7 100644 --- a/rllm/agents/agent.py +++ b/rllm/agents/agent.py @@ -2,6 +2,7 @@ import uuid from abc import ABC, abstractmethod +from copy import deepcopy from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field @@ -53,8 +54,8 @@ def info(self) -> dict: def info(self, value: dict) -> None: self.metadata = value - # TODO: add deepcopy of chat_completions here — many agents don't deepcopy at step creation def model_post_init(self, __context: Any) -> None: + self.chat_completions = deepcopy(self.chat_completions) if self.model_output is None: return # backfill fields like prompt_ids, response_ids, logprobs, etc. From 46b3356a5efcba5bcdd21431bd1e54421616a4cb Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Mon, 30 Mar 2026 23:13:56 -0700 Subject: [PATCH 10/21] [wip] make fully-async unified trainer compatible with agent flow engines --- .../src/rllm_model_gateway/client.py | 28 ++++++++++ .../src/rllm_model_gateway/proxy.py | 44 +++++++++++++++ .../src/rllm_model_gateway/server.py | 15 +++++ rllm/experimental/engine/agent_flow_engine.py | 20 ++----- rllm/experimental/engine/gateway_manager.py | 35 +++++++++++- .../engine/remote_agent_flow_engine.py | 32 ++++++++++- rllm/experimental/protocol.py | 1 + rllm/experimental/unified_trainer.py | 56 ++++++++++++------- rllm/trainer/tinker/tinker_backend.py | 1 + 9 files changed, 194 insertions(+), 38 deletions(-) diff --git a/rllm-model-gateway/src/rllm_model_gateway/client.py b/rllm-model-gateway/src/rllm_model_gateway/client.py index 48292fe8d..e1d1a20bd 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/client.py +++ b/rllm-model-gateway/src/rllm_model_gateway/client.py @@ -140,6 +140,20 @@ def health(self) -> dict[str, Any]: resp.raise_for_status() return resp.json() + # -- Gate (weight sync) ------------------------------------------------ + + def close_gate(self) -> None: + resp = self._http.post(f"{self.gateway_url}/admin/gate/close") + resp.raise_for_status() + + def open_gate(self) -> None: + resp = self._http.post(f"{self.gateway_url}/admin/gate/open") + resp.raise_for_status() + + def wait_for_drain(self, timeout: float | None = None) -> None: + resp = self._http.post(f"{self.gateway_url}/admin/gate/drain", timeout=timeout) + resp.raise_for_status() + class AsyncGatewayClient: """Async variant of :class:`GatewayClient` using ``httpx.AsyncClient``.""" @@ -266,3 +280,17 @@ async def health(self) -> dict[str, Any]: resp = await self._http.get(f"{self.gateway_url}/health") resp.raise_for_status() return resp.json() + + # -- Gate (weight sync) ------------------------------------------------ + + async def close_gate(self) -> None: + resp = await self._http.post(f"{self.gateway_url}/admin/gate/close") + resp.raise_for_status() + + async def open_gate(self) -> None: + resp = await self._http.post(f"{self.gateway_url}/admin/gate/open") + resp.raise_for_status() + + async def wait_for_drain(self, timeout: float | None = None) -> None: + resp = await self._http.post(f"{self.gateway_url}/admin/gate/drain", timeout=timeout) + resp.raise_for_status() diff --git a/rllm-model-gateway/src/rllm_model_gateway/proxy.py b/rllm-model-gateway/src/rllm_model_gateway/proxy.py index 3d9c3430f..1e8a9e38b 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/proxy.py +++ b/rllm-model-gateway/src/rllm_model_gateway/proxy.py @@ -89,6 +89,13 @@ def __init__( self._http: httpx.AsyncClient | None = None self._pending_traces: set[asyncio.Task[None]] = set() + # Gate for weight sync: when closed, new requests wait; in-flight requests finish. + self._gate = asyncio.Event() + self._gate.set() # open by default + self._active_requests: int = 0 + self._drained = asyncio.Event() + self._drained.set() + async def start(self) -> None: self._http = httpx.AsyncClient( timeout=httpx.Timeout(timeout=None), # no timeout — LLM calls can be long @@ -106,6 +113,35 @@ async def stop(self) -> None: await self._http.aclose() self._http = None + # ------------------------------------------------------------------ + # Gate (weight sync) + # ------------------------------------------------------------------ + + def close_gate(self) -> None: + """Block new requests from proceeding. In-flight requests continue.""" + self._gate.clear() + + def open_gate(self) -> None: + """Allow new requests to proceed.""" + self._gate.set() + + async def wait_for_drain(self) -> None: + """Wait until all in-flight requests have completed.""" + if self._active_requests == 0: + return + self._drained.clear() + await self._drained.wait() + + def _on_request_start(self) -> None: + self._active_requests += 1 + self._drained.clear() + + def _on_request_end(self) -> None: + self._active_requests -= 1 + if self._active_requests <= 0: + self._active_requests = 0 + self._drained.set() + # ------------------------------------------------------------------ # Main entrypoint # ------------------------------------------------------------------ @@ -116,6 +152,14 @@ async def _ensure_started(self) -> None: async def handle(self, request: Request) -> Response: """Proxy *request* to an inference worker, capture trace, return response.""" + await self._gate.wait() + self._on_request_start() + try: + return await self._handle_inner(request) + finally: + self._on_request_end() + + async def _handle_inner(self, request: Request) -> Response: await self._ensure_started() session_id: str | None = request.state.session_id originally_requested_logprobs: bool = getattr(request.state, "originally_requested_logprobs", False) diff --git a/rllm-model-gateway/src/rllm_model_gateway/server.py b/rllm-model-gateway/src/rllm_model_gateway/server.py index 74b3cb18b..55ba85f6f 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/server.py +++ b/rllm-model-gateway/src/rllm_model_gateway/server.py @@ -270,6 +270,21 @@ async def reload(): # Placeholder for hot-reload return {"status": "ok"} + @app.post("/admin/gate/close") + async def gate_close(): + proxy.close_gate() + return {"status": "closed"} + + @app.post("/admin/gate/open") + async def gate_open(): + proxy.open_gate() + return {"status": "open"} + + @app.post("/admin/gate/drain") + async def gate_drain(): + await proxy.wait_for_drain() + return {"status": "drained"} + # -- Proxy catch-all (must be last) ------------------------------------ @app.api_route( diff --git a/rllm/experimental/engine/agent_flow_engine.py b/rllm/experimental/engine/agent_flow_engine.py index 60c3ccdd6..fc2a99887 100644 --- a/rllm/experimental/engine/agent_flow_engine.py +++ b/rllm/experimental/engine/agent_flow_engine.py @@ -117,7 +117,7 @@ async def execute_tasks( for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): rollout_idx = task_id_counter[task_id] task_id_counter[task_id] += 1 - futures.append(self._process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) + futures.append(self.process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) with tqdm(total=len(tasks), desc="Generating trajectories") as pbar: for future in asyncio.as_completed(futures): @@ -141,7 +141,7 @@ async def execute_tasks( return ordered_results - async def _process_task_with_retry( + async def process_task_with_retry( self, task: dict, task_id: str, @@ -202,12 +202,8 @@ async def _run_single(self, task: dict, uid: str, is_validation: bool = False) - """Run a single AgentFlow task: execute, evaluate, enrich.""" loop = asyncio.get_event_loop() - # 1. Create gateway session (run in executor to avoid blocking event loop) - await loop.run_in_executor( - self.executor, - self.gateway.create_session, - uid, - ) + # 1. Create gateway session + await self.gateway.acreate_session(uid, is_validation=is_validation) session_url = self.gateway.get_session_url(uid) # 2. Build config @@ -239,12 +235,8 @@ async def _run_single(self, task: dict, uid: str, is_validation: bool = False) - traj.reward = eval_output.reward episode.is_correct = eval_output.is_correct - # 5. Retrieve traces from gateway (run in executor to avoid blocking event loop) - traces = await loop.run_in_executor( - self.executor, - self.gateway.get_traces, - uid, - ) + # 5. Retrieve traces from gateway + traces = await self.gateway.aget_traces(uid) # 6. Enrich episode with token data enriched = self._enrich_episode(episode, traces, uid, task) diff --git a/rllm/experimental/engine/gateway_manager.py b/rllm/experimental/engine/gateway_manager.py index 44e75b474..68524acca 100644 --- a/rllm/experimental/engine/gateway_manager.py +++ b/rllm/experimental/engine/gateway_manager.py @@ -18,7 +18,7 @@ import time from typing import TYPE_CHECKING, Any -from rllm_model_gateway.client import GatewayClient +from rllm_model_gateway.client import AsyncGatewayClient, GatewayClient from rllm_model_gateway.models import TraceRecord if TYPE_CHECKING: @@ -89,6 +89,7 @@ def __init__(self, config: DictConfig, mode: str = "thread") -> None: self._server: Any = None # uvicorn.Server when using thread mode self._local_handler: Any = None # in-process handler for tinker self._client: GatewayClient | None = None + self._async_client: AsyncGatewayClient | None = None # Per-mode sampling params (extracted from rollout engine in start()) self._train_sampling_params: dict[str, Any] = {} @@ -100,10 +101,18 @@ def gateway_url(self) -> str: @property def client(self) -> GatewayClient: + """Sync client for lifecycle operations (start, stop, health polling).""" if self._client is None: self._client = GatewayClient(self.gateway_url) return self._client + @property + def async_client(self) -> AsyncGatewayClient: + """Async client for runtime operations (sessions, traces).""" + if self._async_client is None: + self._async_client = AsyncGatewayClient(self.gateway_url) + return self._async_client + # -- Lifecycle ----------------------------------------------------------- def start(self, rollout_engine: RolloutEngine) -> None: @@ -158,6 +167,20 @@ def stop(self) -> None: self._local_handler = None + # -- Gate (weight sync) --------------------------------------------------- + + def close_gate(self) -> None: + """Stop forwarding new inference requests through the gateway.""" + self.client.close_gate() + + async def wait_for_drain(self) -> None: + """Wait for all in-flight inference requests to complete.""" + await self.async_client.wait_for_drain() + + def open_gate(self) -> None: + """Resume forwarding inference requests through the gateway.""" + self.client.open_gate() + # -- Session / trace API ------------------------------------------------- def create_session(self, session_id: str, is_validation: bool = False) -> str: @@ -171,6 +194,16 @@ def get_traces(self, session_id: str) -> list[TraceRecord]: self.client.flush() return self.client.get_session_traces(session_id) + # -- Async session / trace API ------------------------------------------- + + async def acreate_session(self, session_id: str, is_validation: bool = False) -> str: + sp = self._val_sampling_params if is_validation else self._train_sampling_params + return await self.async_client.create_session(session_id=session_id, sampling_params=sp or None) + + async def aget_traces(self, session_id: str) -> list[TraceRecord]: + await self.async_client.flush() + return await self.async_client.get_session_traces(session_id) + # -- Worker setup -------------------------------------------------------- def _ensure_workers(self, rollout_engine: RolloutEngine) -> list[str]: diff --git a/rllm/experimental/engine/remote_agent_flow_engine.py b/rllm/experimental/engine/remote_agent_flow_engine.py index c59adc814..56483711b 100644 --- a/rllm/experimental/engine/remote_agent_flow_engine.py +++ b/rllm/experimental/engine/remote_agent_flow_engine.py @@ -76,7 +76,7 @@ async def execute_tasks( uid = f"{task_id}:{rollout_idx}" session_id = str(uuid.uuid4()) - self.gateway.create_session(session_id, is_validation=is_validation) + await self.gateway.acreate_session(session_id, is_validation=is_validation) session_url = self.gateway.get_session_url(session_id) submissions.append( @@ -101,7 +101,7 @@ async def execute_tasks( if not result.finished: logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) result.reward = 0.0 - traces = self.gateway.get_traces(result.session_id) + traces = await self.gateway.aget_traces(result.session_id) episode = _build_episode(traces, result, uid, task) if not result.finished: episode.metadata["error"] = {"message": result.error or "Unknown error"} @@ -123,6 +123,34 @@ async def execute_tasks( return episodes + async def process_task_with_retry( + self, task: dict, task_id: str, rollout_idx: int, result_idx: int, **kwargs, + ) -> tuple[str, int, int, Episode]: + """Process a single task: create session, submit to runtime, retrieve traces, build episode.""" + uid = f"{task_id}:{rollout_idx}" + session_id = str(uuid.uuid4()) + is_validation = kwargs.get("is_validation", False) + + await self.gateway.acreate_session(session_id, is_validation=is_validation) + session_url = self.gateway.get_session_url(session_id) + + submission = TaskSubmission( + task=task, session_id=session_id, task_id=task_id, inference_url=session_url, + ) + results = await self.runtime.execute_tasks([submission], timeout=self.session_timeout) + result = results[0] + + if not result.finished: + logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) + result.reward = 0.0 + + traces = await self.gateway.aget_traces(session_id) + episode = _build_episode(traces, result, uid, task) + if not result.finished: + episode.metadata["error"] = {"message": result.error or "Unknown error"} + + return task_id, rollout_idx, result_idx, episode + def shutdown(self) -> None: """No local resources to clean up (runtime shutdown is separate).""" pass diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index b32741fa6..e529f2aad 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -37,6 +37,7 @@ class BackendProtocol(ABC, Generic[TDataset, TBatch]): name: str = "base_backend" requires_loop: bool = False + needs_weight_sync_gate: bool = True def __init__(self, config: DictConfig, **kwargs): """Initialize the backend. diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index dfada249b..d2fb60e47 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -240,6 +240,9 @@ def __init__( if hasattr(self.backend, "tokenizer"): self.tokenizer = self.backend.tokenizer + # Tracks in-flight async rollout tasks for drain/wait logic + self._in_flight_tasks: set[asyncio.Task] = set() + def _validate_and_setup_configs(self): """Validate and setup common configs.""" # validate common, backend-agnostic configs @@ -473,7 +476,7 @@ async def _fit_fully_async(self, trainer_state: TrainerState) -> None: assert self.config.data.train_batch_size == 1, ( f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" ) - assert not self.agent_workflow_engine.raise_on_error, ( + assert not getattr(self.agent_workflow_engine, "raise_on_error", False), ( "Async training requires raise_on_error=False so that process_task_with_retry always returns an episode" ) coord_config = SyncCoordinatorConfig( @@ -539,10 +542,11 @@ async def _run_rollout(t=task, tid=task_id, ridx=rollout_idx): task=t, task_id=tid, rollout_idx=ridx, result_idx=0 ) await buffer.add_episode(tid, episode) - asyncio.create_task(_run_rollout()) - + t = asyncio.create_task(_run_rollout()) + self._in_flight_tasks.add(t) + t.add_done_callback(self._in_flight_tasks.discard) - await self._wait_for_all_workflows_idle() + await self._wait_for_drain() finally: buffer.mark_generation_complete() @@ -558,7 +562,7 @@ async def _training_loop( fwd_bwd_group_size = self.async_config.fwd_bwd_group_size num_fwd_bwd_passes = mini_batch_size // fwd_bwd_group_size use_total_batches = self.rllm_config.trainer.get("total_batches", -1) > 0 - rollout_engine = self.agent_workflow_engine.rollout_engine + rollout_engine = getattr(self.agent_workflow_engine, "rollout_engine", None) while True: trainer_state.reset_batch() @@ -669,44 +673,54 @@ async def _training_loop( if use_total_batches and trainer_state.global_step >= self.rllm_config.trainer.total_batches: break - async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: SyncCoordinator, rollout_engine: RolloutEngine) -> None: + async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: SyncCoordinator, rollout_engine: RolloutEngine | None) -> None: """Synchronize weights between training and rollout engines. - Two modes depending on partial_rollout: - - partial_rollout=True: Uses rollout engine gate (model-call level). + Gating behavior depends on backend.needs_weight_sync_gate: + - False (e.g. Tinker): skip gating, just update weights in-place. + - True + partial_rollout=True: gate at model-call level (rollout engine or gateway). Workflows block between turns, resume with new weights. - - partial_rollout=False: Uses coordinator generation pause (dispatch level). + - True + partial_rollout=False: pause at dispatch level (coordinator). Workflows finish naturally, gate stays open. """ + gateway = getattr(self.agent_workflow_engine, "gateway", None) + if self.async_config.partial_rollout: - # Block new model calls; in-flight calls finish, workflows pause between turns - rollout_engine.close_gate() - await rollout_engine.wait_for_drain() + if self.backend.needs_weight_sync_gate: + if rollout_engine is not None: + rollout_engine.close_gate() + await rollout_engine.wait_for_drain() + elif gateway is not None: + gateway.close_gate() + await gateway.wait_for_drain() else: - # Stop dispatching new prompts, let all workflows finish naturally coordinator.pause_generation() - await self._wait_for_all_workflows_idle() + await self._wait_for_drain() trainer_state.policy_version = coordinator.policy_version + 1 await self.backend.on_policy_updated(trainer_state) - rollout_engine.weight_version = trainer_state.policy_version + if rollout_engine is not None: + rollout_engine.weight_version = trainer_state.policy_version coordinator.on_sync_complete() if self.async_config.partial_rollout: - rollout_engine.open_gate() + if self.backend.needs_weight_sync_gate: + if rollout_engine is not None: + rollout_engine.open_gate() + elif gateway is not None: + gateway.open_gate() else: coordinator.resume_generation() - async def _wait_for_all_workflows_idle(self) -> None: - """Wait for all n_parallel_tasks workflows to return to the pool.""" - pool = self.agent_workflow_engine - while pool.workflow_queue.qsize() < pool.n_parallel_tasks: + async def _wait_for_drain(self) -> None: + """Wait for all in-flight rollout tasks to complete.""" + while self._in_flight_tasks: await asyncio.sleep(0.1) async def _validate_async_with_pause(self, trainer_state: TrainerState, coordinator: SyncCoordinator) -> dict: """Validation with dispatch-level pause. Waits for workflows to drain, then runs validation.""" coordinator.pause_generation() - await self._wait_for_all_workflows_idle() + await self._wait_for_drain() try: return await self._validate_async(trainer_state) finally: diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index fe1654368..2ae85a2a8 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -67,6 +67,7 @@ class TinkerBackend(BackendProtocol[Iterable, list[tinker.Datum]]): name: str = "tinker" requires_loop: bool = True # Tinker uses async operations + needs_weight_sync_gate: bool = False # Tinker swaps sampling_client in-place, no gating needed def __init__( self, From 497d35a0e8d5da5f272a492bdb08f11a02cd55a0 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Mon, 30 Mar 2026 23:20:29 -0700 Subject: [PATCH 11/21] fix staleness thottling --- rllm/experimental/sync_coordinator.py | 38 ++++++++++++++++++--------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py index 10ac61ace..86a0dfc6a 100644 --- a/rllm/experimental/sync_coordinator.py +++ b/rllm/experimental/sync_coordinator.py @@ -15,23 +15,29 @@ class SyncCoordinatorConfig: @property def max_rollout_quota(self) -> int: - """Max outstanding groups (dispatched but not yet consumed by training).""" + """Max dispatches per sync window (Verl/AReaL formulation).""" return int((1 + self.staleness_threshold) * self.trigger_parameter_sync_step * self.mini_batch_size) class SyncCoordinator: - """Coordinates rollout scheduling and parameter sync between generation and training loops.""" + """Coordinates rollout scheduling and parameter sync between generation and training loops. + + Uses a per-sync-window dispatch counter (matching Verl/AReaL). The counter + resets only on weight sync, not on consume. This guarantees zero staleness + when staleness_threshold=0. + """ def __init__(self, config: SyncCoordinatorConfig): self.config = config self._policy_version: int = 0 - self._outstanding: int = 0 # groups dispatched but not yet consumed by training + self._dispatched_since_sync: int = 0 # groups dispatched in current sync window + self._in_flight: int = 0 # groups dispatched but not yet consumed/filtered self._steps_since_sync: int = 0 self._total_syncs: int = 0 self._total_groups_filtered: int = 0 - # Throttle — blocks generation when outstanding >= max_rollout_quota + # Throttle — blocks generation when dispatched_since_sync >= max_rollout_quota self._throttle_event: asyncio.Event = asyncio.Event() self._throttle_event.set() @@ -47,27 +53,27 @@ def policy_version(self) -> int: def on_group_dispatched(self) -> None: """Generation loop dispatched one prompt (n rollouts).""" - self._outstanding += 1 - if self._outstanding >= self.config.max_rollout_quota: + self._dispatched_since_sync += 1 + self._in_flight += 1 + if self._dispatched_since_sync >= self.config.max_rollout_quota: self._throttle_event.clear() def on_group_consumed(self) -> None: """Training loop consumed one group from the buffer.""" - self._outstanding = max(0, self._outstanding - 1) - self._throttle_event.set() + self._in_flight = max(0, self._in_flight - 1) def on_group_filtered(self) -> None: - """Accumulator filtered out a uniform group. Frees throttle slot and tracks count.""" + """Accumulator filtered out a group. Decrements in-flight count and tracks stats.""" self._total_groups_filtered += 1 - self.on_group_consumed() + self._in_flight = max(0, self._in_flight - 1) async def wait_for_throttle(self) -> None: - """Generation loop blocks here when quota is full.""" + """Generation loop blocks here when dispatch window is full.""" await self._throttle_event.wait() def has_quota(self) -> bool: """Whether the generation loop can dispatch another group.""" - return self._outstanding < self.config.max_rollout_quota + return self._dispatched_since_sync < self.config.max_rollout_quota # --- Weight sync --- @@ -81,6 +87,11 @@ def on_sync_complete(self) -> None: self._policy_version += 1 self._steps_since_sync = 0 self._total_syncs += 1 + # Reset dispatch window. In-flight items span the sync boundary — + # they were dispatched with old weights and count toward the new window. + self._dispatched_since_sync = self._in_flight + if self._dispatched_since_sync < self.config.max_rollout_quota: + self._throttle_event.set() # --- Generation pause (for validation / weight sync if partial_rollout is False) --- @@ -96,7 +107,8 @@ async def wait_for_generation_allowed(self) -> None: def stats(self) -> dict: return { "async/policy_version": self._policy_version, - "async/outstanding_groups": self._outstanding, + "async/dispatched_since_sync": self._dispatched_since_sync, + "async/in_flight_groups": self._in_flight, "async/steps_since_sync": self._steps_since_sync, "async/max_rollout_quota": self.config.max_rollout_quota, "async/total_syncs": self._total_syncs, From 8170c7af22b96f71ab713c1f8933e87148646117 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Tue, 31 Mar 2026 00:33:41 -0700 Subject: [PATCH 12/21] enfore concurrency across engines --- rllm/experimental/engine/agent_flow_engine.py | 96 ++++++++++--------- .../engine/remote_agent_flow_engine.py | 94 +++++++----------- rllm/experimental/unified_trainer.py | 1 + 3 files changed, 82 insertions(+), 109 deletions(-) diff --git a/rllm/experimental/engine/agent_flow_engine.py b/rllm/experimental/engine/agent_flow_engine.py index fc2a99887..1b582dc31 100644 --- a/rllm/experimental/engine/agent_flow_engine.py +++ b/rllm/experimental/engine/agent_flow_engine.py @@ -74,6 +74,7 @@ def __init__( self.raise_on_error = raise_on_error self.episode_logger = episode_logger self.executor = ThreadPoolExecutor(max_workers=n_parallel_tasks) + self._semaphore = asyncio.Semaphore(n_parallel_tasks) # Raise the file descriptor limit to avoid "Too many open files" when # running many parallel agent flows with individual HTTP clients. @@ -150,53 +151,54 @@ async def process_task_with_retry( is_validation: bool = False, ) -> tuple[str, int, int, Episode]: """Process a single task with retry logic.""" - for retry_attempt in range(1, self.retry_limit + 1): - uid = f"{task_id}:{rollout_idx}" - try: - episode = await self._run_single(task, uid, is_validation=is_validation) - episode.id = uid - episode.task = task - - # Display rewards - reward_strs = [] - for traj in episode.trajectories: - reward = "N/A" - if traj.reward is not None: - reward = f"{traj.reward:.1f}" - elif len(traj.steps) > 0: - reward = f"{traj.steps[-1].reward:.1f}" - reward_strs.append(f"{traj.name}: {reward}") - colorful_print( - f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", - fg="green" if episode.is_correct else "yellow", - ) - - return task_id, rollout_idx, result_idx, episode - - except Exception as e: - logger.error("[%s] Attempt %d/%d failed: %s", uid, retry_attempt, self.retry_limit, e) - if retry_attempt < self.retry_limit: - continue - - if self.raise_on_error: - raise - - # Return an error episode - return ( - task_id, - rollout_idx, - result_idx, - Episode( - id=uid, - task=task, - is_correct=False, - termination_reason=TerminationReason.ERROR, - metadata={"error": {"message": str(e)}}, - ), - ) - - # Should not reach here, but satisfy type checker - raise RuntimeError(f"[{uid}] Exhausted all retries") + async with self._semaphore: + for retry_attempt in range(1, self.retry_limit + 1): + uid = f"{task_id}:{rollout_idx}" + try: + episode = await self._run_single(task, uid, is_validation=is_validation) + episode.id = uid + episode.task = task + + # Display rewards + reward_strs = [] + for traj in episode.trajectories: + reward = "N/A" + if traj.reward is not None: + reward = f"{traj.reward:.1f}" + elif len(traj.steps) > 0: + reward = f"{traj.steps[-1].reward:.1f}" + reward_strs.append(f"{traj.name}: {reward}") + colorful_print( + f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", + fg="green" if episode.is_correct else "yellow", + ) + + return task_id, rollout_idx, result_idx, episode + + except Exception as e: + logger.error("[%s] Attempt %d/%d failed: %s", uid, retry_attempt, self.retry_limit, e) + if retry_attempt < self.retry_limit: + continue + + if self.raise_on_error: + raise + + # Return an error episode + return ( + task_id, + rollout_idx, + result_idx, + Episode( + id=uid, + task=task, + is_correct=False, + termination_reason=TerminationReason.ERROR, + metadata={"error": {"message": str(e)}}, + ), + ) + + # Should not reach here, but satisfy type checker + raise RuntimeError(f"[{uid}] Exhausted all retries") async def _run_single(self, task: dict, uid: str, is_validation: bool = False) -> Episode: """Run a single AgentFlow task: execute, evaluate, enrich.""" diff --git a/rllm/experimental/engine/remote_agent_flow_engine.py b/rllm/experimental/engine/remote_agent_flow_engine.py index 56483711b..327901d85 100644 --- a/rllm/experimental/engine/remote_agent_flow_engine.py +++ b/rllm/experimental/engine/remote_agent_flow_engine.py @@ -5,6 +5,7 @@ converting gateway traces to training Steps. """ +import asyncio import logging import uuid from collections import defaultdict @@ -31,12 +32,15 @@ def __init__( runtime: RemoteAgentRuntime, gateway: GatewayManager, session_timeout: float = 900.0, + n_parallel_tasks: int = 128, episode_logger: EpisodeLogger | None = None, ) -> None: self.runtime = runtime self.gateway = gateway self.session_timeout = session_timeout + self.n_parallel_tasks = n_parallel_tasks self.episode_logger = episode_logger + self._semaphore = asyncio.Semaphore(n_parallel_tasks) # Training step tracking (set by set_training_step) self.current_step = 0 @@ -55,59 +59,24 @@ async def execute_tasks( is_validation: bool = False, **kwargs, ) -> list[Episode]: - """Submit tasks to remote runtime, gather results, build Episodes from gateway traces. - - 1. Prepare submissions (create gateway sessions) - 2. Submit all and gather results concurrently via runtime - 3. Retrieve traces from gateway + build Episodes - """ + """Submit tasks to remote runtime, gather results, build Episodes from gateway traces.""" if task_ids is None: task_ids = [str(uuid.uuid4()) for _ in tasks] - # Phase 1: Prepare submissions task_id_counter: dict[str, int] = defaultdict(int) - submissions: list[TaskSubmission] = [] - # Map session_id -> (idx, uid, task) for result correlation - session_metadata: dict[str, tuple[int, str, dict]] = {} + results: list[Episode | None] = [None] * len(tasks) + futures = [] for idx, (task, task_id) in enumerate(zip(tasks, task_ids, strict=True)): rollout_idx = task_id_counter[task_id] task_id_counter[task_id] += 1 - uid = f"{task_id}:{rollout_idx}" - session_id = str(uuid.uuid4()) - - await self.gateway.acreate_session(session_id, is_validation=is_validation) - session_url = self.gateway.get_session_url(session_id) + futures.append(self.process_task_with_retry(task, task_id, rollout_idx, idx, is_validation=is_validation)) - submissions.append( - TaskSubmission( - task=task, - session_id=session_id, - task_id=task_id, - inference_url=session_url, - ) - ) - session_metadata[session_id] = (idx, uid, task) + for future in asyncio.as_completed(futures): + task_id, rollout_idx, idx, episode = await future + results[idx] = episode - # Phase 2: Submit all and gather results concurrently - logger.info("Submitting %d tasks to remote runtime (timeout=%.0fs)", len(submissions), self.session_timeout) - remote_results = await self.runtime.execute_tasks(submissions, timeout=self.session_timeout) - - # Phase 3: Retrieve traces from gateway + build Episodes (match by session_id) - episode_map: dict[int, Episode] = {} - - for result in remote_results: - idx, uid, task = session_metadata[result.session_id] - if not result.finished: - logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) - result.reward = 0.0 - traces = await self.gateway.aget_traces(result.session_id) - episode = _build_episode(traces, result, uid, task) - if not result.finished: - episode.metadata["error"] = {"message": result.error or "Unknown error"} - episode_map[idx] = episode - - episodes = [episode_map[i] for i in range(len(tasks))] + episodes: list[Episode] = results # type: ignore[assignment] # Log episodes if logger is provided if self.episode_logger is not None: @@ -126,30 +95,31 @@ async def execute_tasks( async def process_task_with_retry( self, task: dict, task_id: str, rollout_idx: int, result_idx: int, **kwargs, ) -> tuple[str, int, int, Episode]: - """Process a single task: create session, submit to runtime, retrieve traces, build episode.""" - uid = f"{task_id}:{rollout_idx}" - session_id = str(uuid.uuid4()) - is_validation = kwargs.get("is_validation", False) + """Process a single task with concurrency control.""" + async with self._semaphore: + uid = f"{task_id}:{rollout_idx}" + session_id = str(uuid.uuid4()) + is_validation = kwargs.get("is_validation", False) - await self.gateway.acreate_session(session_id, is_validation=is_validation) - session_url = self.gateway.get_session_url(session_id) + await self.gateway.acreate_session(session_id, is_validation=is_validation) + session_url = self.gateway.get_session_url(session_id) - submission = TaskSubmission( - task=task, session_id=session_id, task_id=task_id, inference_url=session_url, - ) - results = await self.runtime.execute_tasks([submission], timeout=self.session_timeout) - result = results[0] + submission = TaskSubmission( + task=task, session_id=session_id, task_id=task_id, inference_url=session_url, + ) + results = await self.runtime.execute_tasks([submission], timeout=self.session_timeout) + result = results[0] - if not result.finished: - logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) - result.reward = 0.0 + if not result.finished: + logger.warning("[%s] Remote task failed (assigning reward=0): %s", uid, result.error) + result.reward = 0.0 - traces = await self.gateway.aget_traces(session_id) - episode = _build_episode(traces, result, uid, task) - if not result.finished: - episode.metadata["error"] = {"message": result.error or "Unknown error"} + traces = await self.gateway.aget_traces(session_id) + episode = _build_episode(traces, result, uid, task) + if not result.finished: + episode.metadata["error"] = {"message": result.error or "Unknown error"} - return task_id, rollout_idx, result_idx, episode + return task_id, rollout_idx, result_idx, episode def shutdown(self) -> None: """No local resources to clean up (runtime shutdown is separate).""" diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index d2fb60e47..e0d538e46 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -221,6 +221,7 @@ def __init__( runtime=self._remote_runtime, gateway=self._gateway, session_timeout=remote_runtime_config.session_timeout, + n_parallel_tasks=self.rllm_config.workflow.n_parallel_tasks, episode_logger=self.episode_logger, ) else: From 2f8e2f1cfe91089f89a439e09620ddc7efc416b2 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Fri, 3 Apr 2026 15:26:08 -0700 Subject: [PATCH 13/21] fix fully async, refactor metrics --- .../train_countdown_unified_tinker.py | 28 +++++ .../train_countdown_unified_tinker_async.sh | 36 ++++++ .../train_countdown_unified_tinker_sync.sh | 31 ++++++ pyproject.toml | 2 +- rllm/agents/agent.py | 3 + rllm/engine/rollout/rollout_engine.py | 19 +++- rllm/engine/rollout/tinker_engine.py | 3 + rllm/experimental/buffer.py | 34 +++--- rllm/experimental/common/transform.py | 26 ++--- rllm/experimental/common/visualization.py | 34 ++++++ .../engine/unified_workflow_engine.py | 16 ++- rllm/experimental/metrics.py | 12 +- rllm/experimental/sync_coordinator.py | 30 +++-- rllm/experimental/unified_trainer.py | 104 ++++++++++++------ rllm/parser/__init__.py | 28 ++--- rllm/rewards/countdown_reward.py | 20 ++-- rllm/trainer/tinker/tinker_backend.py | 4 - rllm/trainer/tinker/tinker_metrics_utils.py | 60 +--------- rllm/trainer/tinker/tinker_policy_trainer.py | 13 ++- rllm/trainer/tinker/transform.py | 18 +-- 20 files changed, 331 insertions(+), 190 deletions(-) create mode 100644 examples/countdown/unified_trainer/train_countdown_unified_tinker.py create mode 100644 examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh create mode 100644 examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker.py b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py new file mode 100644 index 000000000..9b114ed79 --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py @@ -0,0 +1,28 @@ +import hydra + +from rllm.data.dataset import DatasetRegistry +from rllm.rewards.countdown_reward import countdown_reward_fn +from rllm.experimental.unified_trainer import AgentTrainer +from rllm.workflows.simple_workflow import SimpleWorkflow + + +@hydra.main(config_path="pkg://rllm.experimental.config", config_name="unified", version_base=None) +def main(config): + train_dataset = DatasetRegistry.load_dataset("countdown", "train") + test_dataset = DatasetRegistry.load_dataset("countdown", "test") + + trainer = AgentTrainer( + workflow_class=SimpleWorkflow, + workflow_args={ + "reward_function": countdown_reward_fn, + }, + config=config, + train_dataset=train_dataset, + val_dataset=test_dataset, + backend="tinker", + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh b/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh new file mode 100644 index 000000000..1f3e8586b --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh @@ -0,0 +1,36 @@ +set -x + +python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \ + rllm/backend=tinker \ + model.name=Qwen/Qwen3-8B \ + model.lora_rank=32 \ + training.group_size=8 \ + training.learning_rate=2e-5 \ + training.max_length=4096 \ + sampling.train.temperature=1.0 \ + sampling.train.top_p=1.0 \ + sampling.val.temperature=1.0 \ + sampling.val.top_p=1.0 \ + validation.group_size=1 \ + rllm.workflow.n_parallel_tasks=256 \ + rllm.workflow.retry_limit=1 \ + rllm.workflow.raise_on_error=false \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.train_batch_size=1 \ + data.val_batch_size=1024 \ + rllm.algorithm.adv_estimator=grpo \ + rllm.algorithm.norm_adv_by_std_in_grpo=true \ + rllm.async_training.enable=true \ + rllm.async_training.mini_batch_size=32 \ + rllm.async_training.fwd_bwd_group_size=8 \ + rllm.async_training.staleness_threshold=0.5 \ + rllm.async_training.trigger_parameter_sync_step=1 \ + rllm.async_training.partial_rollout=true \ + rllm.trainer.total_epochs=1 \ + rllm.trainer.logger='[wandb]' \ + rllm.trainer.project_name='rllm-countdown' \ + rllm.trainer.experiment_name='countdown-tinker-async-staleness-0.5' \ + rllm.trainer.val_before_train=true \ + rllm.trainer.test_freq=10 \ + rllm.trainer.save_freq=-1 diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh b/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh new file mode 100644 index 000000000..a9d2748fa --- /dev/null +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh @@ -0,0 +1,31 @@ +set -x + +python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \ + rllm/backend=tinker \ + model.name=Qwen/Qwen3-8B \ + model.lora_rank=32 \ + training.group_size=8 \ + training.learning_rate=2e-5 \ + training.max_length=4096 \ + sampling.train.temperature=1.0 \ + sampling.train.top_p=1.0 \ + sampling.val.temperature=1.0 \ + sampling.val.top_p=1.0 \ + validation.group_size=1 \ + rllm.workflow.n_parallel_tasks=256 \ + rllm.workflow.retry_limit=1 \ + rllm.workflow.raise_on_error=false \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.train_batch_size=32 \ + data.val_batch_size=1024 \ + rllm.algorithm.adv_estimator=grpo \ + rllm.algorithm.norm_adv_by_std_in_grpo=true \ + rllm.async_training.enable=false \ + rllm.trainer.total_epochs=1 \ + rllm.trainer.logger='[wandb]' \ + rllm.trainer.project_name='rllm-countdown' \ + rllm.trainer.experiment_name='countdown-tinker-sync' \ + rllm.trainer.val_before_train=true \ + rllm.trainer.test_freq=10 \ + rllm.trainer.save_freq=-1 diff --git a/pyproject.toml b/pyproject.toml index 4de38f6b0..c1a8622a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ verl = [ "transformers>=4.55.0,<5.0.0", "numpy", "torch", - "torchvision, + "torchvision", "flash-attn>=2.8.1", "qwen-vl-utils", "ray", diff --git a/rllm/agents/agent.py b/rllm/agents/agent.py index b3239bea7..f72e334d3 100644 --- a/rllm/agents/agent.py +++ b/rllm/agents/agent.py @@ -92,6 +92,7 @@ def _serialize_value(value): "prompt_ids": self.prompt_ids, "response_ids": self.response_ids, "logprobs": self.logprobs, + "routing_matrices": self.routing_matrices, "chat_completions": _serialize_value(self.chat_completions), "observation": self.observation, "thought": self.thought, @@ -114,6 +115,7 @@ def from_dict(cls, data: dict) -> Step: prompt_ids=data["prompt_ids"], response_ids=data["response_ids"], logprobs=data["logprobs"], + routing_matrices=data.get("routing_matrices"), chat_completions=data["chat_completions"], observation=data["observation"], thought=data["thought"], @@ -140,6 +142,7 @@ def from_model_output(cls, model_output: ModelOutput, messages: list[dict] | Non action=action, model_response=model_output.content or "", model_output=model_output, + weight_version=model_output.weight_version, ) diff --git a/rllm/engine/rollout/rollout_engine.py b/rllm/engine/rollout/rollout_engine.py index c7cf14ebf..4e1b98ab9 100644 --- a/rllm/engine/rollout/rollout_engine.py +++ b/rllm/engine/rollout/rollout_engine.py @@ -1,10 +1,13 @@ import asyncio +import logging from dataclasses import dataclass from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput from rllm.parser import ChatTemplateParser from rllm.tools.tool_base import ToolCall +logger = logging.getLogger(__name__) + @dataclass class ModelOutput: @@ -22,6 +25,7 @@ class ModelOutput: completion_length: int = 0 finish_reason: str | None = None weight_version: int | None = None # policy version at time of generation + metrics: dict | None = None # per-turn server metrics (e.g. ttft, queue durations) def to_dict(self): return { @@ -38,6 +42,7 @@ def to_dict(self): "completion_length": self.completion_length, "finish_reason": self.finish_reason, "weight_version": self.weight_version, + "metrics": self.metrics, } @classmethod @@ -56,6 +61,7 @@ def from_dict(cls, data: dict): completion_length=data.get("completion_length", 0), finish_reason=data.get("finish_reason"), weight_version=data.get("weight_version"), + metrics=data.get("metrics"), ) @@ -77,10 +83,12 @@ def __init__(self, *args, **kwargs): def close_gate(self) -> None: """Close the gate. New model calls will block at wait_for_gate().""" + logger.info(f"[RolloutEngine] Closing gate. Active calls: {self._active_calls}") self._gate.clear() def open_gate(self) -> None: """Open the gate, releasing any blocked model calls.""" + logger.info(f"[RolloutEngine] Opening gate. Active calls: {self._active_calls}") self._gate.set() def on_model_call_complete(self) -> None: @@ -89,15 +97,23 @@ def on_model_call_complete(self) -> None: if self._active_calls <= 0: self._active_calls = 0 self._drained_event.set() + logger.debug("[RolloutEngine] All active calls drained.") + else: + logger.debug(f"[RolloutEngine] Model call complete. Active calls: {self._active_calls}") async def wait_for_gate(self) -> None: """Wait until gate is open, then register as active call. Engines will call this at the START of get_model_response().""" + if not self._gate.is_set(): + logger.info(f"[RolloutEngine] Waiting for gate to open. Active calls: {self._active_calls}") await self._gate.wait() self._active_calls += 1 self._drained_event.clear() + logger.debug(f"[RolloutEngine] Gate passed. Active calls: {self._active_calls}") async def wait_for_drain(self) -> None: """Wait until all active model calls complete. Used during weight sync.""" + if not self._drained_event.is_set(): + logger.info(f"[RolloutEngine] Waiting for drain. Active calls: {self._active_calls}") await self._drained_event.wait() # --- Model response --- @@ -107,8 +123,9 @@ async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutp async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: await self.wait_for_gate() try: + weight_version = self.weight_version result = await self._get_model_response(messages, **kwargs) - result.weight_version = self.weight_version + result.weight_version = weight_version return result finally: self.on_model_call_complete() diff --git a/rllm/engine/rollout/tinker_engine.py b/rllm/engine/rollout/tinker_engine.py index f70cbec39..51d9cc354 100644 --- a/rllm/engine/rollout/tinker_engine.py +++ b/rllm/engine/rollout/tinker_engine.py @@ -91,6 +91,7 @@ def __init__( kwargs: Additional keyword arguments - strip_thinking_from_history: Whether to strip thinking from history (only for Qwen3Renderer) """ + super().__init__() self.base_url = base_url self.model_name = model_name self.max_prompt_length = max_prompt_length @@ -230,9 +231,11 @@ def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutp prompt_ids=prompt_ids, completion_ids=response_tokens, logprobs=logprobs, + routing_matrices=None, prompt_length=_flat_token_input_length(token_input), completion_length=len(response_tokens), finish_reason=finish_reason, + metrics=None, ) @override diff --git a/rllm/experimental/buffer.py b/rllm/experimental/buffer.py index a88794d04..4ac3d0623 100644 --- a/rllm/experimental/buffer.py +++ b/rllm/experimental/buffer.py @@ -70,6 +70,7 @@ def __init__( rs_config: RejectionSamplingConfig, episode_offload_dir: str | None = None, trajectory_group_offload_dir: str | None = None, + pbar: "tqdm | None" = None, ): self._group_size = group_size self._coordinator = coordinator @@ -78,6 +79,7 @@ def __init__( self._transform_config = transform_config self._cf_config = cf_config self._rs_config = rs_config + self._pbar = pbar # Episode offloading: pending episodes serialized to disk self._episode_offload_dir = episode_offload_dir @@ -146,7 +148,11 @@ async def add_episode(self, task_id: str, episode: Episode) -> bool: if len(self._pending[task_id]) < self._group_size: return False - # Group complete — load all episodes + # Group complete — tick progress bar + if self._pbar is not None: + self._pbar.update(1) + + # Load all episodes if self._episode_offload_dir: episodes = await self._load_pending_episodes(task_id) else: @@ -161,18 +167,12 @@ async def add_episode(self, task_id: str, episode: Episode) -> bool: traj_groups, transform_metrics = transform_episodes_to_trajectory_groups( episodes, self._transform_config, self._cf_config, ) - # Strip heavy fields from episodes for UI logging, free bulk memory - for ep in episodes: - for traj in ep.trajectories: - for step in traj.steps: - for key in _EPISODE_STRIP_KEYS: - setattr(step, key, [] if key in _EPISODE_STRIP_LIST_DEFAULTS else None) self._aggregator.record_dict(transform_metrics) # 3. Drop groups with too few trajectories before_min_traj = len(traj_groups) traj_groups = [g for g in traj_groups if len(g.trajectories) >= self._rs_config.min_trajs_per_group] - self._aggregator.record("buffer/filtered_min_trajs", before_min_traj - len(traj_groups)) + self._aggregator.record("groups/dropped_min_trajs", before_min_traj - len(traj_groups)) if not traj_groups: self._coordinator.on_group_filtered() @@ -198,7 +198,7 @@ async def add_episode(self, task_id: str, episode: Episode) -> bool: ) ] filtered_zero_adv = before_adv - len(traj_groups) - self._aggregator.record("buffer/filtered_zero_adv", filtered_zero_adv) + self._aggregator.record("groups/dropped_zero_adv", filtered_zero_adv) if not traj_groups: self._coordinator.on_group_filtered() @@ -257,14 +257,14 @@ def _record_episode_metrics(self, episodes: list[Episode]) -> None: except (TypeError, ValueError): continue - # Sequence lengths and turn counts from trajectories - for traj in ep.trajectories: - n_steps = len(traj.steps) - prompt_tokens = sum(len(s.prompt_ids) for s in traj.steps) - response_tokens = sum(len(s.response_ids) for s in traj.steps) - self._aggregator.record("episode/num_turns", n_steps) - self._aggregator.record("episode/prompt_tokens", prompt_tokens) - self._aggregator.record("episode/response_tokens", response_tokens) + # Episode-level totals across all trajectories + total_turns = sum(len(traj.steps) for traj in ep.trajectories) + total_prompt_tokens = sum(len(s.prompt_ids) for traj in ep.trajectories for s in traj.steps) + total_response_tokens = sum(len(s.response_ids) for traj in ep.trajectories for s in traj.steps) + self._aggregator.record("episode/num_turns", total_turns) + self._aggregator.record("episode/prompt_tokens", total_prompt_tokens) + self._aggregator.record("episode/response_tokens", total_response_tokens) + self._aggregator.record("episode/correct", 1.0 if ep.is_correct else 0.0) @staticmethod def _min_weight_version(episodes: list[Episode]) -> int: diff --git a/rllm/experimental/common/transform.py b/rllm/experimental/common/transform.py index 723b491af..c792c823b 100644 --- a/rllm/experimental/common/transform.py +++ b/rllm/experimental/common/transform.py @@ -150,7 +150,7 @@ def _build_trajectory_groups(episodes: list[Episode], compact_filtering_config: """ -def _get_transform_metrics(episodes: list[Episode], groups: list[TrajectoryGroup], prefix: str = "grouping") -> dict: +def _get_transform_metrics(episodes: list[Episode], groups: list[TrajectoryGroup], prefix: str = "groups") -> dict: """ Get metrics for the transformation pipeline. """ @@ -180,12 +180,11 @@ def _default_traj_grouping_hook(episodes: list[Episode], transform_config: Trans """ trajectory_groups = _build_trajectory_groups(episodes, compact_filtering_config) # part 1 reward_warnings = _validate_and_propagate_rewards(trajectory_groups, transform_config) # part 2 - - for warning in reward_warnings[:LOG_N_WARNINGS]: - logger.warning(warning) - - if len(reward_warnings) > LOG_N_WARNINGS: - logger.warning(f"Skipping {len(reward_warnings) - LOG_N_WARNINGS} more similar warnings with reward validation") + if reward_warnings: + for warning in reward_warnings[:LOG_N_WARNINGS]: + logger.debug(warning) + if len(reward_warnings) > LOG_N_WARNINGS: + logger.debug(f"Skipping {len(reward_warnings) - LOG_N_WARNINGS} more similar reward validation warnings") return trajectory_groups @@ -194,7 +193,7 @@ def transform_episodes_to_trajectory_groups( episodes: list[Episode], transform_config: TransformConfig, compact_filtering_config: CompactFilteringConfig | None = None, - metrics_prefix: str = "grouping", + metrics_prefix: str = "groups", traj_grouping_hook: Callable[[list[Episode], TransformConfig, CompactFilteringConfig | None], list[TrajectoryGroup]] = _default_traj_grouping_hook, ) -> tuple[list[TrajectoryGroup], dict]: """ @@ -232,12 +231,11 @@ def transform_episodes_to_trajectory_groups( # Step 1: Name imputation rename_warnings = _impute_trajectory_names(episodes, transform_config) - - for warning in rename_warnings[:LOG_N_WARNINGS]: - logger.warning(warning) - - if len(rename_warnings) > LOG_N_WARNINGS: - logger.warning(f"Skipping {len(rename_warnings) - LOG_N_WARNINGS} more similar warnings with trajectory names") + if rename_warnings: + for warning in rename_warnings[:LOG_N_WARNINGS]: + logger.debug(warning) + if len(rename_warnings) > LOG_N_WARNINGS: + logger.debug(f"Skipping {len(rename_warnings) - LOG_N_WARNINGS} more similar trajectory name warnings") # Step 2: Invoke the trajectory grouping hook groups = traj_grouping_hook(episodes, transform_config, compact_filtering_config) diff --git a/rllm/experimental/common/visualization.py b/rllm/experimental/common/visualization.py index cee8e9d01..d7d45d242 100644 --- a/rllm/experimental/common/visualization.py +++ b/rllm/experimental/common/visualization.py @@ -24,6 +24,40 @@ class VisualizationConfig: failure_style: dict[str, Any] = field(default_factory=lambda: {"fg": "red", "bold": True}) +def print_metrics_table(metrics: dict, step: int, title: str | None = None) -> None: + """Print metrics as a formatted Rich table with fallback to plain text.""" + try: + from rich.console import Console + from rich.table import Table + + table = Table(title=title or f"Step {step}", show_header=True, header_style="bold magenta") + table.add_column("Metric", style="cyan", no_wrap=False) + table.add_column("Value", justify="right", style="green") + + for key, value in sorted(metrics.items()): + if isinstance(value, float): + value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" + elif isinstance(value, int): + value_str = str(value) + else: + value_str = str(value) + table.add_row(key, value_str) + + Console().print(table) + except ImportError: + print(f"\n{title or f'Step {step}'}") + print("=" * 60) + for key, value in sorted(metrics.items()): + if isinstance(value, float): + value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" + elif isinstance(value, int): + value_str = str(value) + else: + value_str = str(value) + print(f"{key:40s} {value_str:>15s}") + print("=" * 60) + + def colorful_print(string: str, *args, **kwargs) -> None: end = kwargs.pop("end", "\n") print(click.style(string, *args, **kwargs), end=end, flush=True) diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index 72b71796e..b60999a6c 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -105,6 +105,7 @@ async def initialize_pool(self): assert self.executor is not None, "executor is not initialized" if self.workflow_queue is not None: return + logger.info(f"[WorkflowEngine] Initializing pool with {self.n_parallel_tasks} workflows") self.workflow_queue = asyncio.Queue(maxsize=self.n_parallel_tasks) for i in range(self.n_parallel_tasks): workflow = self.workflow_cls( @@ -115,6 +116,7 @@ async def initialize_pool(self): ) assert workflow.is_multithread_safe(), "Workflows must contain only thread-save environments" self.workflow_queue.put_nowait(workflow) + logger.info(f"[WorkflowEngine] Pool initialized. Queue size: {self.workflow_queue.qsize()}") async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: int, result_idx: int, **kwargs) -> tuple[str, int, int, Episode]: """Process a single task rollout with retry logic based on termination reasons. @@ -133,10 +135,12 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i Exception: If task fails permanently after retry_limit attempts and raise_on_error is True. """ assert self.workflow_queue is not None, "workflow_queue is not initialized" + logger.debug(f"[WorkflowEngine] Waiting for workflow from queue. Available: {self.workflow_queue.qsize()}") workflow = await self.workflow_queue.get() try: for retry_attempt in range(1, self.retry_limit + 1): uid = f"{task_id}:{rollout_idx}" + logger.debug(f"[WorkflowEngine] [{uid}] Starting attempt {retry_attempt}/{self.retry_limit}") workflow.reset(task=task, uid=uid) episode = await workflow.run_with_termination_handling(task=task, uid=uid, **kwargs) @@ -153,9 +157,8 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i elif len(traj.steps) > 0: reward = f"{traj.steps[-1].reward:.1f}" reward_strs.append(f"{traj.name}: {reward}") - colorful_print( - f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}", - fg="green" if episode.is_correct else "yellow", + logger.debug( + f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}" ) if episode.termination_reason != TerminationReason.ERROR: @@ -163,14 +166,14 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i error_tb = episode.info.get("error", {}).get("traceback") if error_tb: - print(error_tb) + logger.error(f"[WorkflowEngine] [{uid}] Error on attempt {retry_attempt}/{self.retry_limit}:\n{error_tb}") if retry_attempt < self.retry_limit: - print(f"[{uid}] Rollout failed on attempt {retry_attempt}/{self.retry_limit}, retrying...") + logger.warning(f"[WorkflowEngine] [{uid}] Rollout failed on attempt {retry_attempt}/{self.retry_limit}, retrying...") continue if not self.raise_on_error: - print(f"[{uid}] Rollout failed permanently after {self.retry_limit} attempts.") + logger.error(f"[WorkflowEngine] [{uid}] Rollout failed permanently after {self.retry_limit} attempts.") else: raise Exception(f"[{uid}] Rollout failed permanently after {self.retry_limit} attempts.") @@ -178,6 +181,7 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i finally: await self.workflow_queue.put(workflow) + logger.debug(f"[WorkflowEngine] Returned workflow to queue. Available: {self.workflow_queue.qsize()}") async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = None, is_validation: bool = False, **kwargs) -> list[Episode]: """Run asynchronous workflow execution with retry logic for multiple tasks. diff --git a/rllm/experimental/metrics.py b/rllm/experimental/metrics.py index 2ee4734d6..55cfa8cd9 100644 --- a/rllm/experimental/metrics.py +++ b/rllm/experimental/metrics.py @@ -13,18 +13,18 @@ # Keys that should be summed rather than averaged. _SUM_KEYS: set[str] = { - "grouping/num_trajs_before_filter", - "grouping/num_trajs_after_filter", - "grouping/num_groups", - "buffer/filtered_min_trajs", - "buffer/filtered_zero_adv", + "groups/num_trajs_before_filter", + "groups/num_trajs_after_filter", + "groups/num_groups", + "groups/dropped_min_trajs", + "groups/dropped_zero_adv", } # Prefixes where "last value" is the correct reduction. _LAST_PREFIXES: tuple[str, ...] = ( "time/", + "train/", "progress/", - "optim/", "async/", ) diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py index 86a0dfc6a..066cfc5ac 100644 --- a/rllm/experimental/sync_coordinator.py +++ b/rllm/experimental/sync_coordinator.py @@ -30,12 +30,11 @@ class SyncCoordinator: def __init__(self, config: SyncCoordinatorConfig): self.config = config - self._policy_version: int = 0 - self._dispatched_since_sync: int = 0 # groups dispatched in current sync window + self._weight_version: int = 0 + self._quota_used: int = 0 # groups counting toward current sync window quota (includes carryover) self._in_flight: int = 0 # groups dispatched but not yet consumed/filtered self._steps_since_sync: int = 0 self._total_syncs: int = 0 - self._total_groups_filtered: int = 0 # Throttle — blocks generation when dispatched_since_sync >= max_rollout_quota self._throttle_event: asyncio.Event = asyncio.Event() @@ -46,16 +45,16 @@ def __init__(self, config: SyncCoordinatorConfig): self._generation_paused.set() @property - def policy_version(self) -> int: - return self._policy_version + def weight_version(self) -> int: + return self._weight_version # --- Throttle --- def on_group_dispatched(self) -> None: """Generation loop dispatched one prompt (n rollouts).""" - self._dispatched_since_sync += 1 + self._quota_used += 1 self._in_flight += 1 - if self._dispatched_since_sync >= self.config.max_rollout_quota: + if self._quota_used >= self.config.max_rollout_quota: self._throttle_event.clear() def on_group_consumed(self) -> None: @@ -63,8 +62,7 @@ def on_group_consumed(self) -> None: self._in_flight = max(0, self._in_flight - 1) def on_group_filtered(self) -> None: - """Accumulator filtered out a group. Decrements in-flight count and tracks stats.""" - self._total_groups_filtered += 1 + """Accumulator filtered out a group. Decrements in-flight count.""" self._in_flight = max(0, self._in_flight - 1) async def wait_for_throttle(self) -> None: @@ -73,7 +71,7 @@ async def wait_for_throttle(self) -> None: def has_quota(self) -> bool: """Whether the generation loop can dispatch another group.""" - return self._dispatched_since_sync < self.config.max_rollout_quota + return self._quota_used < self.config.max_rollout_quota # --- Weight sync --- @@ -84,13 +82,13 @@ def should_sync(self) -> bool: return self._steps_since_sync >= self.config.trigger_parameter_sync_step def on_sync_complete(self) -> None: - self._policy_version += 1 + self._weight_version += 1 self._steps_since_sync = 0 self._total_syncs += 1 # Reset dispatch window. In-flight items span the sync boundary — # they were dispatched with old weights and count toward the new window. - self._dispatched_since_sync = self._in_flight - if self._dispatched_since_sync < self.config.max_rollout_quota: + self._quota_used = self._in_flight + if self._quota_used < self.config.max_rollout_quota: self._throttle_event.set() # --- Generation pause (for validation / weight sync if partial_rollout is False) --- @@ -106,11 +104,11 @@ async def wait_for_generation_allowed(self) -> None: def stats(self) -> dict: return { - "async/policy_version": self._policy_version, - "async/dispatched_since_sync": self._dispatched_since_sync, + "async/weight_version": self._weight_version, + "async/dispatched_since_sync": self._quota_used - self._in_flight, + "async/quota_used": self._quota_used, "async/in_flight_groups": self._in_flight, "async/steps_since_sync": self._steps_since_sync, "async/max_rollout_quota": self.config.max_rollout_quota, "async/total_syncs": self._total_syncs, - "async/total_groups_filtered": self._total_groups_filtered, } diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index e0d538e46..dbf973699 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -1,4 +1,5 @@ import asyncio +import logging import time import uuid from abc import ABC, abstractmethod @@ -8,8 +9,11 @@ from pprint import pprint from typing import Any, Literal +logger = logging.getLogger(__name__) + import numpy as np from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset @@ -34,7 +38,7 @@ _default_traj_grouping_hook, transform_episodes_to_trajectory_groups, ) -from rllm.experimental.common.visualization import visualize_trajectory_last_steps +from rllm.experimental.common.visualization import print_metrics_table, visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine from rllm.experimental.buffer import TrajectoryGroupBuffer from rllm.experimental.metrics import MetricsAggregator @@ -54,7 +58,7 @@ class TrainerState: epoch: int = 0 total_steps: int = 0 is_training: bool = True - policy_version: int = 0 + weight_version: int = 0 # For timing and metrics timing_dict: dict = field(default_factory=dict) metrics: dict = field(default_factory=dict) @@ -138,6 +142,8 @@ def __init__( # Extract the TrajectoryGroup-specific estimator from kwargs self.traj_group_adv_estimator_map = traj_group_adv_estimator_map or {} + # TODO(kylemontgomery1): disaggregate UnitifiedTrainer.__init__ from engine/infra setup + self.backend = backend_cls(config=config, **(backend_args or {})) self._validate_and_setup_configs() @@ -315,6 +321,8 @@ def _setup_logging(self): # Main training loop methods # ========================================================================= + # TODO(kylemontgomery1): better seperation of on policy vs fully async training code + def fit(self): """Main training loop (sync entry point).""" asyncio.run(self.fit_async()) @@ -335,8 +343,7 @@ async def fit_async(self) -> None: await self.backend.on_train_start(trainer_state) if self.rllm_config.trainer.get("val_before_train", True): - val_metrics = await self._validate_async(trainer_state) - pprint(f"Initial validation metrics: {val_metrics}") + await self._validate_async(trainer_state) if self.rllm_config.trainer.get("val_only", False): return @@ -358,6 +365,7 @@ async def _fit_async(self, trainer_state: TrainerState) -> None: async def _fit_on_policy(self, trainer_state: TrainerState) -> None: """Synchronous training loop (the most vanilla, standalone case that does not support minibatching or off-policy training).""" + # TODO(kylemontgomery1): dataloader should be backend-agnostic train_dataloader: Iterable = self.backend.get_dataloader(self.train_dataset, trainer_state) break_via_total_batches = False # used to break the training loop via the `total_batches` parameter use_total_batches = self.rllm_config.trainer.get("total_batches") is not None and self.rllm_config.trainer.total_batches > 0 @@ -384,6 +392,7 @@ async def _fit_on_policy(self, trainer_state: TrainerState) -> None: await self._train_batch_async(batch, trainer_state) await self.backend.on_batch_end(trainer_state) + print_metrics_table(trainer_state.metrics, trainer_state.global_step) self.logger.log( data=trainer_state.metrics, step=trainer_state.global_step, @@ -406,13 +415,13 @@ async def _fit_on_policy(self, trainer_state: TrainerState) -> None: # final validation after training if self.rllm_config.trainer.test_freq > 0: - val_metrics = await self._validate_async(trainer_state) - pprint(f"Final validation metrics: {val_metrics}") + await self._validate_async(trainer_state) async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> None: """Train a batch (async implementation).""" self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=trainer_state.epoch) + # TODO(kylemontgomery1): episode generation should be backend-agnostic # stage 1: generate episodes (async) and collect metrics (sync) trainer_state.episodes = await self.backend.generate_episodes(batch, agent_workflow_engine=self.agent_workflow_engine, is_validation=False) if not trainer_state.has_episodes: @@ -446,6 +455,7 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N await self.backend.process_backend_batch(trainer_state) assert trainer_state.has_backend_batch, "Backend batch is not transformed or processed successfully" + # TODO(kylemontgomery1): compute advantages should be backend-agnostic # stage 6: compute advantages (async) await self.backend.compute_advantages(trainer_state, self.algorithm_config) @@ -508,14 +518,21 @@ async def _fit_fully_async(self, trainer_state: TrainerState) -> None: else: trainer_state.total_steps = len(train_dataloader) * self.rllm_config.trainer.total_epochs - gen_task = asyncio.create_task(self._generation_loop(trainer_state, buffer, coordinator)) - await self._training_loop(trainer_state, buffer, coordinator, aggregator) - if not gen_task.done(): - gen_task.cancel() - try: - await gen_task - except asyncio.CancelledError: - pass + total_tasks = len(train_dataloader) * self.rllm_config.trainer.total_epochs + pbar = tqdm(total=total_tasks, desc="Tasks", unit="task") + buffer._pbar = pbar + + try: + gen_task = asyncio.create_task(self._generation_loop(trainer_state, buffer, coordinator)) + await self._training_loop(trainer_state, buffer, coordinator, aggregator) + if not gen_task.done(): + gen_task.cancel() + try: + await gen_task + except asyncio.CancelledError: + pass + finally: + pbar.close() async def _generation_loop( self, trainer_state: TrainerState, buffer: TrajectoryGroupBuffer, coordinator: SyncCoordinator, @@ -575,9 +592,12 @@ async def _training_loop( buffer_wait_time = 0.0 done = False + buffered = buffer._queue.qsize() + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: waiting for {mini_batch_size} task batches ({num_fwd_bwd_passes} fwd-bwd passes x {fwd_bwd_group_size} groups), {buffered} buffered") + # 1. Pull mini_batch_size task batches total, split into # num_fwd_bwd_passes forward-backward passes of fwd_bwd_group_size each. - for _ in range(num_fwd_bwd_passes): + for pass_idx in range(num_fwd_bwd_passes): chunk_groups: list[TrajectoryGroup] = [] for _ in range(fwd_bwd_group_size): @@ -604,6 +624,7 @@ async def _training_loop( trainer_state.trajectory_groups = chunk_groups if trainer_state.has_trajectory_groups: + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: fwd-bwd pass {pass_idx + 1}/{num_fwd_bwd_passes} ({len(chunk_groups)} groups)") await self.backend.on_batch_start(trainer_state) trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) await self.backend.process_backend_batch(trainer_state) @@ -614,29 +635,35 @@ async def _training_loop( # Only run optimizer step on a full batch if groups_consumed < mini_batch_size: + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: incomplete batch ({groups_consumed}/{mini_batch_size}), stopping") break # 2. Optimizer step + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: optimizer step") await self.backend.update_policy(trainer_state) - # 3. Weight sync + # 3. Capture pre-sync metrics (before weight sync resets coordinator state) + staleness_values = [coordinator.weight_version - v for v in weight_versions] + aggregator.record("async/staleness_mean", float(np.mean(staleness_values))) + aggregator.record("async/staleness_min", float(np.min(staleness_values))) + aggregator.record("async/staleness_max", float(np.max(staleness_values))) + aggregator.record("async/groups_consumed", groups_consumed) + aggregator.record("time/buffer_wait", buffer_wait_time) + pre_sync_coordinator_stats = coordinator.stats() + pre_sync_buffer_stats = buffer.stats() + + # 4. Weight sync coordinator.on_training_step_complete() sync_time = 0.0 if coordinator.should_sync(): + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: triggering weight sync") t0 = time.perf_counter() await self._perform_weight_sync(trainer_state, coordinator, rollout_engine) sync_time = time.perf_counter() - t0 - - # 4. Record training-loop metrics to aggregator - staleness_values = [coordinator.policy_version - v for v in weight_versions] - aggregator.record("async/staleness_mean", float(np.mean(staleness_values))) - aggregator.record("async/staleness_min", float(np.min(staleness_values))) - aggregator.record("async/staleness_max", float(np.max(staleness_values))) - aggregator.record("async/groups_consumed", groups_consumed) - aggregator.record("time/step", time.perf_counter() - step_start) - aggregator.record("time/buffer_wait", buffer_wait_time) + logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: weight sync complete ({sync_time:.2f}s)") if sync_time > 0: aggregator.record("time/weight_sync", sync_time) + aggregator.record("time/step", time.perf_counter() - step_start) # Set all trajectory groups and stripped episodes for visualization/logging trainer_state.trajectory_groups = all_trajectory_groups @@ -650,14 +677,20 @@ async def _training_loop( show_workflow_metadata=True, ) - # 5. on_batch_end writes backend metrics (progress, optim, timing) to trainer_state.metrics - await self.backend.on_batch_end(trainer_state) - - # 6. Flush aggregator and merge snapshots into trainer_state.metrics for logging + # 5. Flush aggregator and merge pre-sync snapshots into trainer_state.metrics trainer_state.metrics.update(aggregator.flush()) - trainer_state.metrics.update(buffer.stats()) - trainer_state.metrics.update(coordinator.stats()) + trainer_state.metrics.update(pre_sync_buffer_stats) + trainer_state.metrics.update(pre_sync_coordinator_stats) + + # 6. Compute derived metrics + step_time = trainer_state.metrics.get("time/step", 1.0) + trainer_state.metrics["async/trainer_idle_ratio"] = buffer_wait_time / max(step_time, 1e-9) + + # 7. on_batch_end writes backend metrics (progress, optim, timing) + await self.backend.on_batch_end(trainer_state) + # 7. Print and log + print_metrics_table(trainer_state.metrics, trainer_state.global_step) self.logger.log( data=trainer_state.metrics, step=trainer_state.global_step, @@ -698,10 +731,10 @@ async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: S coordinator.pause_generation() await self._wait_for_drain() - trainer_state.policy_version = coordinator.policy_version + 1 + trainer_state.weight_version = coordinator.weight_version + 1 await self.backend.on_policy_updated(trainer_state) if rollout_engine is not None: - rollout_engine.weight_version = trainer_state.policy_version + rollout_engine.weight_version = trainer_state.weight_version coordinator.on_sync_complete() if self.async_config.partial_rollout: @@ -745,7 +778,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: for batch in val_dataloader: # Generate episodes and transform to trajectory groups val_episodes = await self.backend.generate_episodes(batch, agent_workflow_engine=self.agent_workflow_engine, is_validation=True) - val_trajectory_groups, transform_metrics = transform_episodes_to_trajectory_groups(val_episodes, self.transform_config, self.cf_config, traj_grouping_hook=self.traj_grouping_hook) + val_trajectory_groups, _ = transform_episodes_to_trajectory_groups(val_episodes, self.transform_config, self.cf_config, traj_grouping_hook=self.traj_grouping_hook) reward_metrics = collect_reward_and_advantage_from_trajectory_groups(val_trajectory_groups, self.algorithm_config, collect_advantage=False) is_correct_lst.extend([episode.is_correct for episode in val_episodes]) @@ -758,7 +791,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: for key, value in episode.metrics.items(): workflow_metrics_by_source[data_source][key].append(float(value)) - for key, value in (transform_metrics | reward_metrics).items(): + for key, value in reward_metrics.items(): val_metrics[f"val/{key}"].append(value) test_end = time.perf_counter() @@ -788,6 +821,7 @@ async def _validate_async(self, trainer_state: TrainerState) -> dict: # post-process the val metrics to reduce any "list values" into scalars reduce_metrics_lists(val_metrics) + print_metrics_table(val_metrics, trainer_state.global_step, title="Validation") self.logger.log(data=val_metrics, step=trainer_state.global_step) await self.backend.on_validation_end(trainer_state) return val_metrics diff --git a/rllm/parser/__init__.py b/rllm/parser/__init__.py index 5abeef15c..8b5e0b993 100644 --- a/rllm/parser/__init__.py +++ b/rllm/parser/__init__.py @@ -12,21 +12,6 @@ ] -def __getattr__(name): - _chat_template_classes = { - "ChatTemplateParser", - "DeepseekQwenChatTemplateParser", - "LlamaChatTemplateParser", - "QwenChatTemplateParser", - } - if name in _chat_template_classes: - import importlib - - mod = importlib.import_module("rllm.parser.chat_template_parser") - return getattr(mod, name) - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - PARSER_REGISTRY = { "r1": R1ToolParser, "qwen": QwenToolParser, @@ -38,9 +23,20 @@ def get_tool_parser(parser_name: str) -> type[ToolParser]: return PARSER_REGISTRY[parser_name] +_CHAT_TEMPLATE_CLASSES = { + "ChatTemplateParser", + "DeepseekQwenChatTemplateParser", + "LlamaChatTemplateParser", + "QwenChatTemplateParser", +} + + def __getattr__(name): + if name in _CHAT_TEMPLATE_CLASSES: + import importlib + mod = importlib.import_module("rllm.parser.chat_template_parser") + return getattr(mod, name) if name == "TinkerChatTemplateParser": from rllm.parser.tinker_parser import TinkerChatTemplateParser - return TinkerChatTemplateParser raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/rllm/rewards/countdown_reward.py b/rllm/rewards/countdown_reward.py index ccdf6157a..093d35f95 100644 --- a/rllm/rewards/countdown_reward.py +++ b/rllm/rewards/countdown_reward.py @@ -1,9 +1,12 @@ +import logging import random import re from rllm import Action from rllm.rewards.reward_types import RewardOutput +logger = logging.getLogger(__name__) + def extract_solution(solution_str): """Extract the equation from the solution string.""" @@ -72,20 +75,17 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.1, do_print = random.randint(1, 64) == 1 if do_print: - print("--------------------------------") - print(f"Target: {target} | Numbers: {numbers}") - print(f"Extracted equation: {equation}") - print(f"Solution string: {solution_str}") + logger.debug(f"Target: {target} | Numbers: {numbers} | Equation: {equation} | Solution: {solution_str}") if equation is None: if do_print: - print("No equation found") + logger.debug("No equation found") return 0 # Validate equation uses correct numbers if not validate_equation(equation, numbers): if do_print: - print("Invalid equation") + logger.debug("Invalid equation") return format_score # Evaluate equation @@ -93,20 +93,20 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.1, result = evaluate_equation(equation) if result is None: if do_print: - print("Could not evaluate equation") + logger.debug("Could not evaluate equation") return format_score if abs(result - target) < 1e-5: # Account for floating point precision if do_print: - print(f"Correct equation: {equation} = {result}") + logger.debug(f"Correct equation: {equation} = {result}") return score else: if do_print: - print(f"Wrong result: equation = {result}, target = {target}") + logger.debug(f"Wrong result: equation = {result}, target = {target}") return format_score except Exception: if do_print: - print("Error evaluating equation") + logger.debug("Error evaluating equation") return format_score diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 2ae85a2a8..2bf3b143b 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -443,10 +443,6 @@ async def on_batch_end(self, trainer_state: TrainerState) -> None: learning_rate = trainer_state.extra_info.get("scheduled_learning_rate", self.learning_rate) update_training_metrics(trainer_state, learning_rate, trainer_state.total_steps) - # Print metrics table - if trainer_state.metrics: - print_metrics_table(trainer_state.metrics, trainer_state.global_step) - async def on_epoch_start(self, trainer_state: TrainerState) -> None: """Called at the start of an epoch.""" logger.info(f"Starting epoch {trainer_state.epoch}") diff --git a/rllm/trainer/tinker/tinker_metrics_utils.py b/rllm/trainer/tinker/tinker_metrics_utils.py index 3976081f3..805ed9ad4 100644 --- a/rllm/trainer/tinker/tinker_metrics_utils.py +++ b/rllm/trainer/tinker/tinker_metrics_utils.py @@ -5,60 +5,12 @@ import tinker import torch +from rllm.experimental.common.visualization import print_metrics_table # noqa: F401 (re-export) from rllm.experimental.unified_trainer import TrainerState logger = logging.getLogger(__name__) -def print_metrics_table(metrics: dict, step: int): - """ - Print metrics as a formatted table (similar to tinker_cookbook). - - Args: - metrics: Dictionary of metrics - step: Current step number - """ - try: - from rich.console import Console - from rich.table import Table - - console = Console() - - # Create table - table = Table(title=f"Step {step}", show_header=True, header_style="bold magenta") - table.add_column("Metric", style="cyan", no_wrap=False) - table.add_column("Value", justify="right", style="green") - - # Sort metrics by key for consistent ordering - sorted_metrics = sorted(metrics.items()) - - for key, value in sorted_metrics: - # Format value based on type - if isinstance(value, float): - value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" - elif isinstance(value, int): - value_str = str(value) - else: - value_str = str(value) - - table.add_row(key, value_str) - - console.print(table) - - except ImportError: - # Fallback to simple text table if rich is not available - print(f"\nStep {step}") - print("=" * 60) - for key, value in sorted(metrics.items()): - if isinstance(value, float): - value_str = f"{value:.6f}" if abs(value) < 1000 else f"{value:.2f}" - elif isinstance(value, int): - value_str = str(value) - else: - value_str = str(value) - print(f"{key:40s} {value_str:>15s}") - print("=" * 60) - def compute_kl_and_entropy_metrics(training_datums: list[tinker.Datum], training_logprobs: list[torch.Tensor]) -> dict: """ @@ -102,10 +54,10 @@ def compute_kl_and_entropy_metrics(training_datums: list[tinker.Datum], training perplexity = torch.exp(torch.tensor(entropy_sample)).item() return { - "optim/kl_sample_train_v1": kl_sample_train_v1, - "optim/kl_sample_train_v2": kl_sample_train_v2, - "optim/entropy": entropy_sample, - "optim/perplexity": perplexity, + "train/kl_sample_train_v1": kl_sample_train_v1, + "train/kl_sample_train_v2": kl_sample_train_v2, + "train/entropy": entropy_sample, + "train/perplexity": perplexity, } @@ -125,7 +77,7 @@ def update_training_metrics(trainer_state: TrainerState, learning_rate: float, t { "progress/batch": trainer_state.global_step, "progress/epoch": trainer_state.epoch, - "optim/lr": learning_rate, + "progress/lr": learning_rate, } ) diff --git a/rllm/trainer/tinker/tinker_policy_trainer.py b/rllm/trainer/tinker/tinker_policy_trainer.py index 57d7bd6ca..30f1caa6f 100644 --- a/rllm/trainer/tinker/tinker_policy_trainer.py +++ b/rllm/trainer/tinker/tinker_policy_trainer.py @@ -255,12 +255,18 @@ async def forward_backward_from_trajectory_groups( # Wait for completion and extract logprobs fwd_bwd_results = await asyncio.gather(*fwd_bwd_futures) - # Extract training logprobs from loss_fn_outputs + # Extract training logprobs and server-side metrics from results training_logprobs = [] for fwd_bwd_result in fwd_bwd_results: for output in fwd_bwd_result.loss_fn_outputs: logprobs = output["logprobs"].to_torch() training_logprobs.append(logprobs) + # Capture server-side metrics (e.g. loss) under train/ prefix + if fwd_bwd_result.metrics: + for k, v in fwd_bwd_result.metrics.items(): + if k.startswith("clock_cycle"): + continue + adv_metrics[f"train/{k.replace(':', '/')}"] = v return training_datums, training_logprobs, adv_metrics @@ -335,6 +341,11 @@ async def fused_forward_backward_and_optim_step( for output in fwd_bwd_result.loss_fn_outputs: logprobs = output["logprobs"].to_torch() training_logprobs.append(logprobs) + if fwd_bwd_result.metrics: + for k, v in fwd_bwd_result.metrics.items(): + if k.startswith("clock_cycle"): + continue + adv_metrics[f"train/{k.replace(':', '/')}"] = v return training_datums, training_logprobs, adv_metrics, scheduled_learning_rate diff --git a/rllm/trainer/tinker/transform.py b/rllm/trainer/tinker/transform.py index 036a62b27..7a0d73862 100644 --- a/rllm/trainer/tinker/transform.py +++ b/rllm/trainer/tinker/transform.py @@ -166,12 +166,12 @@ def transform_trajectory_groups_to_datums( datums = [] # step 2: iterate over all steps and build the Tinker Datum objects - datums_per_traj = [] + seqs_per_traj = [] seq_lengths = [] for group in trajectory_groups: for trajectory in group.trajectories: traj_datums = trajectory_to_datums(trajectory, router_replay=algorithm_config.router_replay) - datums_per_traj.append(len(traj_datums)) + seqs_per_traj.append(len(traj_datums)) for d in traj_datums: seq_lengths.append(d.model_input.length) if algorithm_config.estimator_map: @@ -179,13 +179,13 @@ def transform_trajectory_groups_to_datums( else: datums.extend(traj_datums) - if datums_per_traj: + if seqs_per_traj: import numpy as _np - adv_metrics["train/datums_per_traj/mean"] = _np.mean(datums_per_traj) - adv_metrics["train/datums_per_traj/min"] = _np.min(datums_per_traj) - adv_metrics["train/datums_per_traj/max"] = _np.max(datums_per_traj) - adv_metrics["train/seq_length/mean"] = _np.mean(seq_lengths) - adv_metrics["train/seq_length/min"] = _np.min(seq_lengths) - adv_metrics["train/seq_length/max"] = _np.max(seq_lengths) + adv_metrics["batch/seqs_per_traj/mean"] = _np.mean(seqs_per_traj) + adv_metrics["batch/seqs_per_traj/min"] = _np.min(seqs_per_traj) + adv_metrics["batch/seqs_per_traj/max"] = _np.max(seqs_per_traj) + adv_metrics["batch/seq_length/mean"] = _np.mean(seq_lengths) + adv_metrics["batch/seq_length/min"] = _np.min(seq_lengths) + adv_metrics["batch/seq_length/max"] = _np.max(seq_lengths) return (datums if not algorithm_config.estimator_map else datums_dict), adv_metrics From 0f01be7ce9ae8f395d29ed707193c9a176564a4a Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sat, 4 Apr 2026 12:17:39 -0700 Subject: [PATCH 14/21] revert engine/rollout to main, restore experimental/rollout engines Move enhanced rollout engines (tinker, verl, completer, types) back to rllm/experimental/rollout/ and revert rllm/engine/rollout/ to match main. Fix import paths in experimental code and tinker backend/transform. Co-Authored-By: Claude Opus 4.6 (1M context) --- rllm/engine/rollout/__init__.py | 23 +- rllm/engine/rollout/openai_engine.py | 70 +--- rllm/engine/rollout/rollout_engine.py | 92 +---- rllm/engine/rollout/tinker_engine.py | 390 ++++++++++-------- rllm/engine/rollout/verl_engine.py | 64 +-- .../engine/unified_workflow_engine.py | 2 +- rllm/experimental/protocol.py | 2 +- rllm/experimental/rollout/__init__.py | 25 +- .../rollout/completer.py | 4 +- rllm/experimental/rollout/rollout_engine.py | 81 +++- rllm/experimental/rollout/tinker_engine.py | 271 ++++++++++++ .../{engine => experimental}/rollout/types.py | 0 rllm/experimental/rollout/verl_engine.py | 133 ++++++ .../test_examples/opsd/math_opsd_workflow.py | 4 +- rllm/experimental/unified_trainer.py | 2 +- rllm/experimental/verl/verl_backend.py | 2 +- rllm/trainer/tinker/tinker_backend.py | 2 +- rllm/trainer/tinker/transform.py | 4 +- 18 files changed, 746 insertions(+), 425 deletions(-) rename rllm/{engine => experimental}/rollout/completer.py (97%) create mode 100644 rllm/experimental/rollout/tinker_engine.py rename rllm/{engine => experimental}/rollout/types.py (100%) create mode 100644 rllm/experimental/rollout/verl_engine.py diff --git a/rllm/engine/rollout/__init__.py b/rllm/engine/rollout/__init__.py index 471682f61..47995ca85 100644 --- a/rllm/engine/rollout/__init__.py +++ b/rllm/engine/rollout/__init__.py @@ -1,26 +1,11 @@ -from typing import TYPE_CHECKING - +# Avoid importing concrete engines at module import time to prevent circular imports from .rollout_engine import ModelOutput, RolloutEngine -from .types import TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput, VerlTokenInput, VerlTokenOutput - -if TYPE_CHECKING: - from .tinker_engine import TinkerEngine - from .verl_engine import VerlEngine __all__ = [ "ModelOutput", "RolloutEngine", "OpenAIEngine", - "TinkerEngine", "VerlEngine", - # Token types - "TokenInput", - "TokenOutput", - "TinkerTokenInput", - "TinkerTokenOutput", - "VerlTokenInput", - "VerlTokenOutput", - "Tokenizer", ] @@ -29,10 +14,6 @@ def __getattr__(name): from .openai_engine import OpenAIEngine as _OpenAIEngine return _OpenAIEngine - if name == "TinkerEngine": - from .tinker_engine import TinkerEngine as _TinkerEngine - - return _TinkerEngine if name == "VerlEngine": try: from .verl_engine import VerlEngine as _VerlEngine @@ -40,4 +21,4 @@ def __getattr__(name): return _VerlEngine except Exception: raise AttributeError(name) from None - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + raise AttributeError(name) diff --git a/rllm/engine/rollout/openai_engine.py b/rllm/engine/rollout/openai_engine.py index ec95588b3..60c130505 100644 --- a/rllm/engine/rollout/openai_engine.py +++ b/rllm/engine/rollout/openai_engine.py @@ -3,7 +3,6 @@ import logging import os from io import BytesIO -import json import openai from PIL import Image @@ -11,7 +10,7 @@ from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine from rllm.globals import THOUGHT_DELIMITER_END, THOUGHT_DELIMITER_START from rllm.parser import ChatTemplateParser -from rllm.tools.tool_base import Tool, ToolCall, ToolOutput +from rllm.tools.tool_base import Tool from rllm.workflows import TerminationEvent, TerminationReason @@ -81,57 +80,6 @@ def _prepare_max_tokens_param(self, sampling_params: dict, prompt_length: int = return {"max_tokens": max_tokens} - def _convert_openai_to_tool_calls(self, tool_calls: list[dict] | None) -> list[ToolCall]: - """Convert OpenAI tool calls to internal ToolCall objects.""" - if not tool_calls: - return [] - processed_tool_calls: list[ToolCall] = [] - for tool_call in tool_calls: - try: - arguments = json.loads(tool_call.function.arguments) - except Exception as e: - print(f"Error parsing tool call: {tool_call.function.arguments}, error: {e}") - continue - processed_tool_calls.append( - ToolCall( - id=tool_call.id, - name=tool_call.function.name, - arguments=arguments, - ) - ) - return processed_tool_calls - - def _convert_tool_calls_to_openai(self, tool_calls: list[ToolCall] | None) -> list[dict] | None: - """Convert internal ToolCall objects to OpenAI format using base class method.""" - if not tool_calls: - return None - return [tool_call.to_openai_format() if isinstance(tool_call, ToolCall) else tool_call for tool_call in tool_calls] - - def _convert_tool_outputs_to_openai(self, tool_outputs: list[ToolOutput] | None) -> list[dict] | None: - """Convert internal ToolOutput objects to OpenAI format using base class method.""" - if not tool_outputs: - return None - return [tool_output.to_openai_format() if isinstance(tool_output, ToolOutput) else tool_output for tool_output in tool_outputs] - - def _prepare_messages_for_openai(self, messages: list[dict]) -> list[dict]: - """Convert messages from internal format to OpenAI format.""" - openai_messages = [] - for msg in messages: - role = msg.get("role") - if role == "assistant": - openai_msg = {"role": "assistant", "content": msg.get("content")} - if "tool_calls" in msg and msg["tool_calls"]: - openai_msg["tool_calls"] = self._convert_tool_calls_to_openai(msg["tool_calls"]) - openai_messages.append(openai_msg) - elif role == "tool": - assert "tool_outputs" in msg, "Tool message must contain tool_outputs" - tool_msgs = self._convert_tool_outputs_to_openai(msg["tool_outputs"]) - if tool_msgs: - openai_messages.extend(tool_msgs) - else: - openai_messages.append(msg) - return openai_messages - async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: kwargs.pop("application_id", None) kwargs.pop("validate", None) @@ -142,22 +90,16 @@ async def chat_completion(self, messages: list[dict], **kwargs) -> ModelOutput: sampling_params.update(kwargs) create_params = self._prepare_max_tokens_param(sampling_params) - sampling_params.update(create_params) - - tools = sampling_params.pop("tools", self.tools) - if tools: - tools = [tool.json if isinstance(tool, Tool) else tool for tool in tools] - - # Convert messages from to OpenAI format - openai_messages = self._prepare_messages_for_openai(messages) + converted_messages = self._convert_messages_to_openai_format(messages) retries = self.api_retries while retries > 0: try: - response = await self.client.chat.completions.create(model=self.model, messages=openai_messages, tools=tools, timeout=3600, **sampling_params) + response = await self.client.chat.completions.create(model=self.model, messages=converted_messages, timeout=3600, **create_params, **sampling_params) + content = response.choices[0].message.content reasoning = response.choices[0].message.reasoning if hasattr(response.choices[0].message, "reasoning") and isinstance(response.choices[0].message.reasoning, str) else "" - tool_calls = self._convert_openai_to_tool_calls(response.choices[0].message.tool_calls) + tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") and isinstance(response.choices[0].message.tool_calls, list) else [] # Build text with reasoning if available, otherwise use content if reasoning: @@ -283,7 +225,7 @@ async def completion(self, prompt: str | list[int], **kwargs) -> ModelOutput: print(f"Error: {e}, retrying...") await asyncio.sleep(1) - async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: if self._use_chat_completions: accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) if accumulate_reasoning: diff --git a/rllm/engine/rollout/rollout_engine.py b/rllm/engine/rollout/rollout_engine.py index 4e1b98ab9..7f3895429 100644 --- a/rllm/engine/rollout/rollout_engine.py +++ b/rllm/engine/rollout/rollout_engine.py @@ -1,13 +1,7 @@ -import asyncio -import logging from dataclasses import dataclass -from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput -from rllm.parser import ChatTemplateParser from rllm.tools.tool_base import ToolCall -logger = logging.getLogger(__name__) - @dataclass class ModelOutput: @@ -15,17 +9,14 @@ class ModelOutput: content: str | None = None reasoning: str | None = None tool_calls: list[ToolCall] | None = None - prompt_ids: TokenInput | None = None + prompt_ids: list[int] | None = None completion_ids: list[int] | None = None multi_modal_inputs: dict[str, list] | None = None logprobs: list[float] | None = None # completion logprobs prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids - routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) prompt_length: int = 0 completion_length: int = 0 finish_reason: str | None = None - weight_version: int | None = None # policy version at time of generation - metrics: dict | None = None # per-turn server metrics (e.g. ttft, queue durations) def to_dict(self): return { @@ -41,8 +32,6 @@ def to_dict(self): "prompt_length": self.prompt_length, "completion_length": self.completion_length, "finish_reason": self.finish_reason, - "weight_version": self.weight_version, - "metrics": self.metrics, } @classmethod @@ -60,90 +49,15 @@ def from_dict(cls, data: dict): prompt_length=data.get("prompt_length", 0), completion_length=data.get("completion_length", 0), finish_reason=data.get("finish_reason"), - weight_version=data.get("weight_version"), - metrics=data.get("metrics"), ) class RolloutEngine: - chat_parser: ChatTemplateParser | None = None - tokenizer: Tokenizer | None = None - is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks - def __init__(self, *args, **kwargs): - # Gate mechanism for pausing model calls during weight sync - self._gate: asyncio.Event = asyncio.Event() - self._gate.set() # open by default - self._active_calls: int = 0 - self._drained_event: asyncio.Event = asyncio.Event() - self._drained_event.set() # initially drained (no active calls) - self.weight_version: int = 0 - - # --- Gate mechanism --- - - def close_gate(self) -> None: - """Close the gate. New model calls will block at wait_for_gate().""" - logger.info(f"[RolloutEngine] Closing gate. Active calls: {self._active_calls}") - self._gate.clear() - - def open_gate(self) -> None: - """Open the gate, releasing any blocked model calls.""" - logger.info(f"[RolloutEngine] Opening gate. Active calls: {self._active_calls}") - self._gate.set() - - def on_model_call_complete(self) -> None: - """Unregister active call. Engines will call this at the END of get_model_response().""" - self._active_calls -= 1 - if self._active_calls <= 0: - self._active_calls = 0 - self._drained_event.set() - logger.debug("[RolloutEngine] All active calls drained.") - else: - logger.debug(f"[RolloutEngine] Model call complete. Active calls: {self._active_calls}") - - async def wait_for_gate(self) -> None: - """Wait until gate is open, then register as active call. Engines will call this at the START of get_model_response().""" - if not self._gate.is_set(): - logger.info(f"[RolloutEngine] Waiting for gate to open. Active calls: {self._active_calls}") - await self._gate.wait() - self._active_calls += 1 - self._drained_event.clear() - logger.debug(f"[RolloutEngine] Gate passed. Active calls: {self._active_calls}") - - async def wait_for_drain(self) -> None: - """Wait until all active model calls complete. Used during weight sync.""" - if not self._drained_event.is_set(): - logger.info(f"[RolloutEngine] Waiting for drain. Active calls: {self._active_calls}") - await self._drained_event.wait() - - # --- Model response --- - async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - raise NotImplementedError(f"_get_model_response is not implemented for {self.__class__.__name__}") + pass async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - await self.wait_for_gate() - try: - weight_version = self.weight_version - result = await self._get_model_response(messages, **kwargs) - result.weight_version = weight_version - return result - finally: - self.on_model_call_complete() - - def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: - """ - Assemble model output from a token output. - """ - raise NotImplementedError("assemble_model_output is not implemented") - - async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TokenOutput: - """Obtain the token output from the given token input.""" - raise NotImplementedError("get_token_output_from_token_input is not implemented") - - @property - def supports_token_in_token_out(self) -> bool: - """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" - return False + raise NotImplementedError("get_model_response is not implemented") async def wake_up(self): pass diff --git a/rllm/engine/rollout/tinker_engine.py b/rllm/engine/rollout/tinker_engine.py index 51d9cc354..c6e35e211 100644 --- a/rllm/engine/rollout/tinker_engine.py +++ b/rllm/engine/rollout/tinker_engine.py @@ -1,76 +1,37 @@ -from typing import cast +import json import tinker from tinker.types import ModelInput from tinker_cookbook import model_info, renderers -from typing_extensions import override # need to use typing_extensions for python < 3.12 from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.engine.rollout.types import ImageProcessor, TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput -from rllm.parser.tinker_parser import TinkerChatTemplateParser +from rllm.parser import ChatTemplateParser +from rllm.tools.tool_base import ToolCall from rllm.workflows import TerminationEvent, TerminationReason -""" -Utility functions for Tinker engine. Partly borrowed from -https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py -""" - - -def _flat_token_input_to_model_input(token_input: TinkerTokenInput) -> ModelInput: - """Convert a flat token input to a ModelInput.""" - if not token_input: # empty list - return ModelInput(chunks=[]) - - out: list[tinker.ModelInputChunk] = [] - current_text_chunk: list[int] = [] - - def flush_text_chunk(): - if current_text_chunk: - out.append(tinker.EncodedTextChunk(tokens=current_text_chunk)) - current_text_chunk.clear() - - for elem in token_input: - if isinstance(elem, int): - current_text_chunk.append(elem) - else: - flush_text_chunk() - out.append(elem) - - flush_text_chunk() # final clear up - return tinker.ModelInput(chunks=out) - - -def _flat_token_input_length(token_input: TokenInput) -> int: - """Get the length of a flat token input. This nicely handles both text and image inputs""" - length = 0 - for elem in token_input: - if isinstance(elem, int): - length += 1 - else: - length += elem.length - return length - class TinkerEngine(RolloutEngine): """ RolloutEngine implementation using Tinker for model inference. - - Wraps the tinker renderer with a TinkerChatTemplateParser, which provides - unified prompt building (including tool spec injection) and response parsing - (content, reasoning, tool_calls). """ def __init__( self, model_name: str, - tokenizer: Tokenizer, + tokenizer, service_client: tinker.ServiceClient, - base_url: str | None = None, + sampling_client: tinker.SamplingClient = None, max_prompt_length: int = 4096, max_response_length: int = 4096, max_model_length: int = 32768, sampling_params: dict | None = None, - image_processor: ImageProcessor | None = None, + val_sampling_params: dict | None = None, + bypass_render_with_parser: bool = False, + processor=None, + image_processor=None, + disable_thinking: bool = False, + accumulate_reasoning: bool = False, + reasoning_effort: str = "medium", renderer_name: str | None = None, **kwargs, ): @@ -81,43 +42,55 @@ def __init__( model_name: Name of the model to use tokenizer: Tokenizer for encoding/decoding service_client: Tinker ServiceClient instance - base_url: Tinker service URL (default = null for local) + sampling_client: Tinker SamplingClient instance max_prompt_length: Maximum prompt length in tokens max_response_length: Maximum response length in tokens max_model_length: Maximum total length (prompt + response) in tokens - sampling_params: Default sampling parameters (temperature, top_p, etc.) + sampling_params: Default sampling parameters for training (temperature, top_p, etc.) + val_sampling_params: Sampling parameters for validation (defaults to sampling_params if not provided) + bypass_render_with_parser: If True, use ChatTemplateParser instead of Tinker's renderer + processor: Optional processor for multimodal models (used when bypass_render_with_parser=True) image_processor: Optional image processor for vision-language models (used with renderer) - renderer_name: Optional renderer name to use (None = auto-detect from model) - kwargs: Additional keyword arguments - - strip_thinking_from_history: Whether to strip thinking from history (only for Qwen3Renderer) + disable_thinking: Whether to disable thinking in generation prompt (used when bypass_render_with_parser=True) + accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) + renderer_name: Override renderer name (None = auto-detect from model) """ - super().__init__() - self.base_url = base_url self.model_name = model_name self.max_prompt_length = max_prompt_length self.max_response_length = max_response_length - self.max_model_length = max_model_length - 1 + self.max_model_length = max_model_length - 1 # Reserve 1 token for logprob computation self.tokenizer = tokenizer + self.sampling_params = sampling_params or {} + self.val_sampling_params = val_sampling_params or self.sampling_params + self.validate = False + self.bypass_render_with_parser = bypass_render_with_parser + self.accumulate_reasoning = accumulate_reasoning + self.reasoning_effort = reasoning_effort - self.train_sampling_params = dict(sampling_params.get("train", {})) if sampling_params else {} - self.val_sampling_params = dict(sampling_params.get("val", {})) if sampling_params else {} # Initialize Tinker service client self.service_client = service_client - # Initialize the renderer and wrap with TinkerChatTemplateParser - renderer_name = renderer_name or model_info.get_recommended_renderer_name(self.model_name) - renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) - - if "strip_thinking_from_history" in kwargs and isinstance(kwargs["strip_thinking_from_history"], bool) and hasattr(renderer, "strip_thinking_from_history"): - renderer.strip_thinking_from_history = kwargs["strip_thinking_from_history"] - - self.chat_parser: TinkerChatTemplateParser = TinkerChatTemplateParser(renderer) - self.stop_sequences = self.chat_parser.stop_sequences - - # Sampling client will be set via set_sampling_client() - self.sampling_client: tinker.SamplingClient | None = None + if bypass_render_with_parser: + self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=disable_thinking) + self.renderer = None + if hasattr(self.chat_parser, "stop_sequences") and self.chat_parser.stop_sequences: + self.stop_sequences = self.chat_parser.stop_sequences + elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token: + self.stop_sequences = [tokenizer.eos_token] + else: + raise ValueError("No stop sequences found for tokenizer or chat parser") + else: + # Use explicit renderer_name if provided, otherwise auto-detect + renderer_name = renderer_name or model_info.get_recommended_renderer_name(self.model_name) + # Pass image_processor for VLM support with Tinker renderer + self.renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) + self.chat_parser = None + self.stop_sequences = self.renderer.get_stop_sequences() + + # Sampling client can be set later via set_sampling_client() + self.sampling_client = sampling_client - def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None: + def set_sampling_client(self, sampling_client): """ Set the sampling client for inference. @@ -126,6 +99,34 @@ def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None: """ self.sampling_client = sampling_client + def _convert_images_to_content_list(self, messages: list[dict]) -> list[dict]: + """ + Convert messages from standard format to Tinker renderer format. + + Standard format: {"role": "user", "content": "text", "images": [PIL.Image]} + Tinker format: {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": "..."}]} + + Args: + messages: List of messages in standard format + + Returns: + List of messages in Tinker renderer format + """ + converted = [] + for msg in messages: + if "images" in msg and msg["images"]: + # Convert to content list format + content_list = [] + for img in msg["images"]: + content_list.append({"type": "image", "image": img}) + content_list.append({"type": "text", "text": msg.get("content", "")}) + converted.append({**msg, "content": content_list}) + # Remove the images key since it's now in content + del converted[-1]["images"] + else: + converted.append(msg) + return converted + def _prepare_max_tokens(self, requested_max_tokens: int, prompt_length: int) -> int: """ Prepare max_tokens parameter, adjusting for max_model_length if needed. @@ -148,80 +149,157 @@ def _prepare_max_tokens(self, requested_max_tokens: int, prompt_length: int) -> return max_tokens - @property - def supports_token_in_token_out(self) -> bool: - """Tinker sampling client does support returning prompt_ids, so this is true.""" - return True - - @override - async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TinkerTokenOutput: + async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: """ - Generate a sampled sequence from a given token input. + Generate model response for a given set of messages. + + Args: + messages: List of message dictionaries (OpenAI format) + **kwargs: Additional parameters including: + - application_id: Session/application ID for tracing + - validate: Whether this is validation (for greedy decoding) + - enforce_max_prompt_length: Whether to enforce max prompt length + - tools: List of tools (used when bypass_render_with_parser=True) + - accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) + + Returns: + ModelOutput with generated text and metadata """ - token_input = cast(TinkerTokenInput, token_input) if self.sampling_client is None: raise RuntimeError("Sampling client not set. Call set_sampling_client() first.") - input_length = _flat_token_input_length(token_input) - + # Extract kwargs + kwargs.pop("application_id", None) + validate = kwargs.pop("validate", False) or self.validate enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) - if enforce_max_prompt_length and input_length > min(self.max_prompt_length, self.max_model_length): - raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - - # prepare sampling params - sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() - - requested_max_tokens = kwargs.pop("max_tokens", kwargs.pop("max_new_tokens", self.max_response_length)) - requested_max_tokens = sampling_params.pop("max_tokens", requested_max_tokens) - max_tokens = self._prepare_max_tokens(requested_max_tokens, input_length) - - if "temperature" in kwargs: - sampling_params["temperature"] = kwargs["temperature"] - if "top_p" in kwargs: - sampling_params["top_p"] = kwargs["top_p"] - if "top_k" in kwargs: - sampling_params["top_k"] = kwargs["top_k"] - - tinker_sampling_params = tinker.types.SamplingParams( - max_tokens=max_tokens, - stop=self.stop_sequences, # type: ignore - **sampling_params, - ) - # call sampling client - model_input = _flat_token_input_to_model_input(token_input) - sample_response: tinker.SampleResponse = await self.sampling_client.sample_async( - prompt=model_input, - num_samples=1, - sampling_params=tinker_sampling_params, - ) + sampling_params = self.val_sampling_params if validate else self.sampling_params - # return sampled sequence from sample response - return sample_response.sequences[0] - - @override - def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: - """ - Assemble model output from a sampled sequence. - """ - sampled_sequence = cast(TinkerTokenOutput, token_output) - response_tokens, logprobs = sampled_sequence.tokens, sampled_sequence.logprobs - - # Parse response using parser (handles content, reasoning, tool_calls) - parsed_output = self.chat_parser.parse_completion(response_tokens) - content = parsed_output.get("content", "") - reasoning = parsed_output.get("reasoning", "") - tool_calls = parsed_output.get("tool_calls", []) - - # decode full text - completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) # type: ignore - finish_reason = sampled_sequence.stop_reason - # special handling for prompt ids, we will break any EncodedTextChunk into ints - prompt_ids = [] - for elem in token_input: - if isinstance(elem, tinker.EncodedTextChunk): - prompt_ids.extend(elem.tokens) + # Extract parser-specific kwargs + tools = kwargs.pop("tools", []) + accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) + reasoning_effort = kwargs.pop("reasoning_effort", self.reasoning_effort) + + if self.bypass_render_with_parser: + # Use ChatTemplateParser + prompt = self.chat_parser.parse( + messages, + add_generation_prompt=True, + is_first_msg=True, + tools=tools, + reasoning_effort=reasoning_effort, + accumulate_reasoning=accumulate_reasoning, + ) + prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) + prompt_length = len(prompt_ids) + + # Check prompt length + if enforce_max_prompt_length and (prompt_length > self.max_prompt_length or prompt_length > self.max_model_length): + raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + + # Dynamically adjust max_tokens based on prompt length + default_max_tokens = sampling_params.get("max_tokens", self.max_response_length) + requested_max_tokens = kwargs.get("max_tokens", kwargs.get("max_new_tokens", default_max_tokens)) + max_tokens = self._prepare_max_tokens(requested_max_tokens, prompt_length) + + # Prepare sampling params (override defaults with kwargs) + sampling_params = tinker.types.SamplingParams( + max_tokens=max_tokens, + stop=self.stop_sequences, + temperature=kwargs.get("temperature", sampling_params.get("temperature", 1.0)), + top_p=kwargs.get("top_p", sampling_params.get("top_p", 1.0)), + ) + + # Convert prompt to Tinker prompt format + tinker_prompt = ModelInput.from_ints(prompt_ids) + + # Call Tinker sampling API + sample_response = await self.sampling_client.sample_async( + prompt=tinker_prompt, + num_samples=1, + sampling_params=sampling_params, + ) + + # Extract response tokens and logprobs + response_tokens = sample_response.sequences[0].tokens + logprobs = sample_response.sequences[0].logprobs + + # Parse response using parser + parsed_output = self.chat_parser.parse_completion(response_tokens) + + content = parsed_output.get("content", "") + reasoning = parsed_output.get("reasoning", "") + tool_calls = parsed_output.get("tool_calls", []) + + # Decode full text + completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) + else: + # Use Tinker renderer (original behavior) + # Convert standard image format to Tinker renderer format + converted_messages = self._convert_images_to_content_list(messages) + # Build prompt using renderer (converts messages to Tinker prompt) + tinker_prompt = self.renderer.build_generation_prompt(converted_messages) + + # For VLM prompts with ImageChunks, to_ints() may not be supported + try: + prompt_ids = tinker_prompt.to_ints() + prompt_length = len(prompt_ids) + except ValueError: + # Prompt contains ImageChunks - skip length enforcement for VLM + prompt_ids = [] + prompt_length = 0 + + # Check prompt length (only for text-only prompts) + if prompt_length > 0 and enforce_max_prompt_length and (prompt_length > self.max_prompt_length or prompt_length > self.max_model_length): + raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + + # Dynamically adjust max_tokens based on prompt length + default_max_tokens = sampling_params.get("max_tokens", self.max_response_length) + requested_max_tokens = kwargs.get("max_tokens", kwargs.get("max_new_tokens", default_max_tokens)) + max_tokens = self._prepare_max_tokens(requested_max_tokens, prompt_length) if prompt_length > 0 else requested_max_tokens + + # Prepare sampling params (override defaults with kwargs) + sampling_params = tinker.types.SamplingParams( + max_tokens=max_tokens, + stop=self.stop_sequences, + temperature=kwargs.get("temperature", sampling_params.get("temperature", 1.0)), + top_p=kwargs.get("top_p", sampling_params.get("top_p", 1.0)), + ) + + # Call Tinker sampling API + sample_response = await self.sampling_client.sample_async( + prompt=tinker_prompt, + num_samples=1, + sampling_params=sampling_params, + ) + + # Extract response tokens and logprobs + response_tokens = sample_response.sequences[0].tokens + logprobs = sample_response.sequences[0].logprobs + + # Parse response using renderer + parsed_msg, _ = self.renderer.parse_response(response_tokens) + raw_content = parsed_msg["content"] + tool_calls = [] + for tc in parsed_msg.get("tool_calls", []): + try: + tool_calls.append(ToolCall(name=tc.function.name, arguments=json.loads(tc.function.arguments))) + except (json.JSONDecodeError, AttributeError): + continue + + if isinstance(raw_content, list): + reasoning = next((p["thinking"] for p in raw_content if p["type"] == "thinking"), "") + content = next((p["text"] for p in raw_content if p["type"] == "text"), "") else: - prompt_ids.append(elem) + content = raw_content + reasoning = "" + + # Decode full text + completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) + + # Determine finish reason + finish_reason = "stop" + if len(response_tokens) >= sampling_params.max_tokens: + finish_reason = "length" return ModelOutput( text=completion_text, @@ -231,41 +309,11 @@ def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutp prompt_ids=prompt_ids, completion_ids=response_tokens, logprobs=logprobs, - routing_matrices=None, - prompt_length=_flat_token_input_length(token_input), + prompt_length=prompt_length, completion_length=len(response_tokens), finish_reason=finish_reason, - metrics=None, ) - @override - async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - """ - Generate model response for a given set of messages. - - Args: - messages: List of message dictionaries (OpenAI format) - **kwargs: Additional parameters including: - - application_id: Session/application ID for tracing - - enforce_max_prompt_length: Whether to enforce max prompt length - - tools: List of tools for tool-augmented generation - - Returns: - ModelOutput with generated text and metadata - """ - # Extract unused kwargs - kwargs.pop("application_id", None) - - # Extract tools - tools = kwargs.pop("tools", []) - - # Build prompt using TinkerChatTemplateParser (handles tools, images, etc.) - tinker_prompt = self.chat_parser.build_prompt(messages, tools=tools) - token_input: TinkerTokenInput = tinker_prompt.chunks - - sampled_sequence = await self.get_token_output_from_token_input(token_input=token_input, **kwargs) - return self.assemble_model_output(token_input=token_input, token_output=sampled_sequence) - async def compute_logprobs(self, ids: list[int]) -> list[float]: ids = ids[: self.max_model_length] return await self.sampling_client.compute_logprobs_async(ModelInput.from_ints(ids)) diff --git a/rllm/engine/rollout/verl_engine.py b/rllm/engine/rollout/verl_engine.py index 2d495ecab..3db4dde28 100644 --- a/rllm/engine/rollout/verl_engine.py +++ b/rllm/engine/rollout/verl_engine.py @@ -1,18 +1,15 @@ import uuid -from typing import cast -from omegaconf import DictConfig -from typing_extensions import override from verl.experimental.agent_loop.agent_loop import AgentLoopManager, AsyncLLMServerManager +from verl.workers.rollout.replica import TokenOutput from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput, VerlTokenOutput from rllm.parser import ChatTemplateParser from rllm.workflows import TerminationEvent, TerminationReason class VerlEngine(RolloutEngine): - def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokenizer: Tokenizer, processor=None, **kwargs): + def __init__(self, config, rollout_manager, tokenizer, processor=None, **kwargs): self.config = config if config.actor_rollout_ref.rollout.name not in ["vllm", "sglang"]: @@ -48,38 +45,25 @@ def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokeni print(f"train_sampling_params: {self.train_sampling_params}") print(f"val_sampling_params: {self.val_sampling_params}") - @property - def supports_token_in_token_out(self) -> bool: - return True + self.validate = False - @override - async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> VerlTokenOutput: - token_input = cast(list[int], token_input) - - input_length = len(token_input) + async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: application_id = kwargs.pop("application_id", str(uuid.uuid4())) + validate = self.validate or kwargs.pop("validate", False) enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) - if enforce_max_prompt_length and input_length > self.max_prompt_length: - raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + # these go to the parser + tools = kwargs.pop("tools", []) + accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) - sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() + sampling_params = self.val_sampling_params.copy() if self.validate or validate else self.train_sampling_params.copy() sampling_params.update(kwargs) + max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length)) # starting from verl 0.7.0, we can pass in per-turn max_tokens into the sampling_params sampling_params["max_tokens"] = max_tokens - token_output = await self.server_manager.generate(request_id=application_id, prompt_ids=token_input, sampling_params=sampling_params) - return token_output - - @override - async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - # these go to the parser - tools = kwargs.pop("tools", []) - accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) - reasoning_effort = kwargs.pop("reasoning_effort", "medium") - - prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning, reasoning_effort=reasoning_effort) + prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning) request_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) # list[int] if any(msg.get("images", None) is not None and msg["role"] == "user" for msg in messages) and self.processor is not None: @@ -93,27 +77,15 @@ async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutp multi_modal_inputs = None prompt_ids = request_prompt_ids - token_output: TokenOutput = await self.get_token_output_from_token_input(token_input=request_prompt_ids, **kwargs) - extra_kwargs = dict(prompt_ids=prompt_ids, multi_modal_inputs=multi_modal_inputs) - return self.assemble_model_output(token_input=request_prompt_ids, token_output=token_output, **extra_kwargs) - - @override - def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput, **kwargs) -> ModelOutput: - prompt_ids = kwargs.pop("prompt_ids", None) - multi_modal_inputs = kwargs.pop("multi_modal_inputs", None) - prompt_length = len(prompt_ids) if prompt_ids is not None else 0 - - token_output = cast(VerlTokenOutput, token_output) - completion_ids = token_output.token_ids - logprobs = token_output.log_probs + prompt_length = len(prompt_ids) + if enforce_max_prompt_length and prompt_length > self.max_prompt_length: + raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) - # convert the stop reason from verl back to the standard finish reason TODO(listar2000): check backward-compatibility - reason_mapping = {"aborted": "abort", "completed": "stop"} - if token_output.stop_reason is not None: - finish_reason = reason_mapping.get(token_output.stop_reason, token_output.stop_reason) - else: - finish_reason = "stop" + token_output: TokenOutput = await self.server_manager.generate(request_id=application_id, prompt_ids=request_prompt_ids, image_data=image_data, sampling_params=sampling_params) # type: ignore + completion_ids: list[int] = token_output.token_ids + logprobs: list[float] = token_output.log_probs + finish_reason = token_output.stop_reason completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True) # TODO: implement parse_completion for the standard parser parsed_output = self.chat_parser.parse_completion(completion_ids) diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index b60999a6c..085b4e8e2 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -11,7 +11,7 @@ from tqdm import tqdm from rllm.agents.agent import Episode -from rllm.engine.rollout import RolloutEngine +from rllm.experimental.rollout import RolloutEngine from rllm.utils import colorful_print from rllm.workflows.store import Store from rllm.workflows.workflow import TerminationReason, Workflow diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index e529f2aad..d5ce100ac 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -16,7 +16,7 @@ from rllm.agents.agent import Episode from rllm.data import Dataset -from rllm.engine.rollout import RolloutEngine +from rllm.experimental.rollout import RolloutEngine from rllm.experimental.common.advantage import AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups if TYPE_CHECKING: diff --git a/rllm/experimental/rollout/__init__.py b/rllm/experimental/rollout/__init__.py index 7fe19012a..50ab03477 100644 --- a/rllm/experimental/rollout/__init__.py +++ b/rllm/experimental/rollout/__init__.py @@ -1,20 +1,18 @@ -# Backward compatibility: re-export from canonical location -from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine # noqa: F401 -from rllm.engine.rollout.types import ( # noqa: F401 - TinkerTokenInput, - TinkerTokenOutput, - TokenInput, - Tokenizer, - TokenOutput, - VerlTokenInput, - VerlTokenOutput, -) +from typing import TYPE_CHECKING + +from .rollout_engine import ModelOutput, RolloutEngine +from .types import TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput, VerlTokenInput, VerlTokenOutput + +if TYPE_CHECKING: + from .tinker_engine import TinkerEngine + from .verl_engine import VerlEngine __all__ = [ "ModelOutput", "RolloutEngine", "TinkerEngine", "VerlEngine", + # Token types "TokenInput", "TokenOutput", "TinkerTokenInput", @@ -26,14 +24,13 @@ def __getattr__(name): - # Lazy imports for engines with heavy dependencies if name == "TinkerEngine": - from rllm.engine.rollout.tinker_engine import TinkerEngine as _TinkerEngine + from .tinker_engine import TinkerEngine as _TinkerEngine return _TinkerEngine if name == "VerlEngine": try: - from rllm.engine.rollout.verl_engine import VerlEngine as _VerlEngine + from .verl_engine import VerlEngine as _VerlEngine return _VerlEngine except Exception: diff --git a/rllm/engine/rollout/completer.py b/rllm/experimental/rollout/completer.py similarity index 97% rename from rllm/engine/rollout/completer.py rename to rllm/experimental/rollout/completer.py index a7be82ce5..8aa034124 100644 --- a/rllm/engine/rollout/completer.py +++ b/rllm/experimental/rollout/completer.py @@ -14,8 +14,8 @@ from typing import TYPE_CHECKING, Any from rllm.agents.agent import Step -from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.engine.rollout.types import TokenInput, Tokenizer, TokenOutput +from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine +from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput from rllm.parser import ChatTemplateParser diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py index 7146be416..eaa2d38a0 100644 --- a/rllm/experimental/rollout/rollout_engine.py +++ b/rllm/experimental/rollout/rollout_engine.py @@ -1,5 +1,5 @@ -from __future__ import annotations - +import asyncio +import logging from dataclasses import dataclass from typing import TYPE_CHECKING @@ -8,6 +8,8 @@ from rllm.parser import ChatTemplateParser from rllm.tools.tool_base import ToolCall +logger = logging.getLogger(__name__) + @dataclass class ModelOutput: @@ -20,9 +22,12 @@ class ModelOutput: multi_modal_inputs: dict[str, list] | None = None logprobs: list[float] | None = None # completion logprobs prompt_logprobs: list[float] | None = None # prompt logprobs aligned to prompt_ids + routing_matrices: list[str] | None = None # per-token routing matrices (R3, transient) prompt_length: int = 0 completion_length: int = 0 finish_reason: str | None = None + weight_version: int | None = None # policy version at time of generation + metrics: dict | None = None # per-turn server metrics (e.g. ttft, queue durations) def to_dict(self): return { @@ -38,6 +43,8 @@ def to_dict(self): "prompt_length": self.prompt_length, "completion_length": self.completion_length, "finish_reason": self.finish_reason, + "weight_version": self.weight_version, + "metrics": self.metrics, } @classmethod @@ -57,6 +64,8 @@ def from_dict(cls, data: dict): prompt_length=data.get("prompt_length", 0), completion_length=data.get("completion_length", 0), finish_reason=data.get("finish_reason"), + weight_version=data.get("weight_version"), + metrics=data.get("metrics"), ) @@ -66,10 +75,64 @@ class RolloutEngine: is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks def __init__(self, *args, **kwargs): - pass + # Gate mechanism for pausing model calls during weight sync + self._gate: asyncio.Event = asyncio.Event() + self._gate.set() # open by default + self._active_calls: int = 0 + self._drained_event: asyncio.Event = asyncio.Event() + self._drained_event.set() # initially drained (no active calls) + self.weight_version: int = 0 + + # --- Gate mechanism --- + + def close_gate(self) -> None: + """Close the gate. New model calls will block at wait_for_gate().""" + logger.info(f"[RolloutEngine] Closing gate. Active calls: {self._active_calls}") + self._gate.clear() + + def open_gate(self) -> None: + """Open the gate, releasing any blocked model calls.""" + logger.info(f"[RolloutEngine] Opening gate. Active calls: {self._active_calls}") + self._gate.set() + + def on_model_call_complete(self) -> None: + """Unregister active call. Engines will call this at the END of get_model_response().""" + self._active_calls -= 1 + if self._active_calls <= 0: + self._active_calls = 0 + self._drained_event.set() + logger.debug("[RolloutEngine] All active calls drained.") + else: + logger.debug(f"[RolloutEngine] Model call complete. Active calls: {self._active_calls}") + + async def wait_for_gate(self) -> None: + """Wait until gate is open, then register as active call. Engines will call this at the START of get_model_response().""" + if not self._gate.is_set(): + logger.info(f"[RolloutEngine] Waiting for gate to open. Active calls: {self._active_calls}") + await self._gate.wait() + self._active_calls += 1 + self._drained_event.clear() + logger.debug(f"[RolloutEngine] Gate passed. Active calls: {self._active_calls}") + + async def wait_for_drain(self) -> None: + """Wait until all active model calls complete. Used during weight sync.""" + if not self._drained_event.is_set(): + logger.info(f"[RolloutEngine] Waiting for drain. Active calls: {self._active_calls}") + await self._drained_event.wait() + + # --- Model response --- + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + raise NotImplementedError(f"_get_model_response is not implemented for {self.__class__.__name__}") async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - raise NotImplementedError("get_model_response is not implemented") + await self.wait_for_gate() + try: + weight_version = self.weight_version + result = await self._get_model_response(messages, **kwargs) + result.weight_version = weight_version + return result + finally: + self.on_model_call_complete() def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: """ @@ -81,13 +144,13 @@ async def get_token_output_from_token_input(self, token_input: TokenInput, **kwa """Obtain the token output from the given token input.""" raise NotImplementedError("get_token_output_from_token_input is not implemented") + @property + def supports_token_in_token_out(self) -> bool: + """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" + return False + async def wake_up(self): pass async def sleep(self): pass - - @property - def supports_token_in_token_out(self) -> bool: - """Whether the engine supports token-in-token-out (TITO) generation. Defaults to false.""" - return False diff --git a/rllm/experimental/rollout/tinker_engine.py b/rllm/experimental/rollout/tinker_engine.py new file mode 100644 index 000000000..f428230ca --- /dev/null +++ b/rllm/experimental/rollout/tinker_engine.py @@ -0,0 +1,271 @@ +from typing import cast + +import tinker +from tinker.types import ModelInput +from tinker_cookbook import model_info, renderers +from typing_extensions import override # need to use typing_extensions for python < 3.12 + +from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine +from rllm.experimental.rollout.types import ImageProcessor, TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput +from rllm.parser.tinker_parser import TinkerChatTemplateParser +from rllm.workflows import TerminationEvent, TerminationReason + +""" +Utility functions for Tinker engine. Partly borrowed from +https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py +""" + + +def _flat_token_input_to_model_input(token_input: TinkerTokenInput) -> ModelInput: + """Convert a flat token input to a ModelInput.""" + if not token_input: # empty list + return ModelInput(chunks=[]) + + out: list[tinker.ModelInputChunk] = [] + current_text_chunk: list[int] = [] + + def flush_text_chunk(): + if current_text_chunk: + out.append(tinker.EncodedTextChunk(tokens=current_text_chunk)) + current_text_chunk.clear() + + for elem in token_input: + if isinstance(elem, int): + current_text_chunk.append(elem) + else: + flush_text_chunk() + out.append(elem) + + flush_text_chunk() # final clear up + return tinker.ModelInput(chunks=out) + + +def _flat_token_input_length(token_input: TokenInput) -> int: + """Get the length of a flat token input. This nicely handles both text and image inputs""" + length = 0 + for elem in token_input: + if isinstance(elem, int): + length += 1 + else: + length += elem.length + return length + + +class TinkerEngine(RolloutEngine): + """ + RolloutEngine implementation using Tinker for model inference. + + Wraps the tinker renderer with a TinkerChatTemplateParser, which provides + unified prompt building (including tool spec injection) and response parsing + (content, reasoning, tool_calls). + """ + + def __init__( + self, + model_name: str, + tokenizer: Tokenizer, + service_client: tinker.ServiceClient, + base_url: str | None = None, + max_prompt_length: int = 4096, + max_response_length: int = 4096, + max_model_length: int = 32768, + sampling_params: dict | None = None, + image_processor: ImageProcessor | None = None, + renderer_name: str | None = None, + **kwargs, + ): + """ + Initialize TinkerEngine. + + Args: + model_name: Name of the model to use + tokenizer: Tokenizer for encoding/decoding + service_client: Tinker ServiceClient instance + base_url: Tinker service URL (default = null for local) + max_prompt_length: Maximum prompt length in tokens + max_response_length: Maximum response length in tokens + max_model_length: Maximum total length (prompt + response) in tokens + sampling_params: Default sampling parameters (temperature, top_p, etc.) + image_processor: Optional image processor for vision-language models (used with renderer) + renderer_name: Optional renderer name to use (None = auto-detect from model) + kwargs: Additional keyword arguments + - strip_thinking_from_history: Whether to strip thinking from history (only for Qwen3Renderer) + """ + super().__init__() + self.base_url = base_url + self.model_name = model_name + self.max_prompt_length = max_prompt_length + self.max_response_length = max_response_length + self.max_model_length = max_model_length - 1 + self.tokenizer = tokenizer + + self.train_sampling_params = dict(sampling_params.get("train", {})) if sampling_params else {} + self.val_sampling_params = dict(sampling_params.get("val", {})) if sampling_params else {} + # Initialize Tinker service client + self.service_client = service_client + + # Initialize the renderer and wrap with TinkerChatTemplateParser + renderer_name = renderer_name or model_info.get_recommended_renderer_name(self.model_name) + renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) + + if "strip_thinking_from_history" in kwargs and isinstance(kwargs["strip_thinking_from_history"], bool) and hasattr(renderer, "strip_thinking_from_history"): + renderer.strip_thinking_from_history = kwargs["strip_thinking_from_history"] + + self.chat_parser: TinkerChatTemplateParser = TinkerChatTemplateParser(renderer) + self.stop_sequences = self.chat_parser.stop_sequences + + # Sampling client will be set via set_sampling_client() + self.sampling_client: tinker.SamplingClient | None = None + + def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None: + """ + Set the sampling client for inference. + + Args: + sampling_client: Tinker SamplingClient instance + """ + self.sampling_client = sampling_client + + def _prepare_max_tokens(self, requested_max_tokens: int, prompt_length: int) -> int: + """ + Prepare max_tokens parameter, adjusting for max_model_length if needed. + + Args: + requested_max_tokens: The requested max_tokens value + prompt_length: The length of the prompt in tokens + + Returns: + Adjusted max_tokens value + """ + max_tokens = requested_max_tokens + + # Adjust for prompt length if max_model_length is set + if self.max_model_length: + remaining = self.max_model_length - prompt_length + if remaining <= max_tokens: + max_tokens = remaining + print(f"Warning: Decreasing max_tokens to {max_tokens} to stay within max_model_length") + + return max_tokens + + @property + def supports_token_in_token_out(self) -> bool: + """Tinker sampling client does support returning prompt_ids, so this is true.""" + return True + + @override + async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> TinkerTokenOutput: + """ + Generate a sampled sequence from a given token input. + """ + token_input = cast(TinkerTokenInput, token_input) + if self.sampling_client is None: + raise RuntimeError("Sampling client not set. Call set_sampling_client() first.") + + input_length = _flat_token_input_length(token_input) + + enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) + if enforce_max_prompt_length and input_length > min(self.max_prompt_length, self.max_model_length): + raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + + # prepare sampling params + sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() + + requested_max_tokens = kwargs.pop("max_tokens", kwargs.pop("max_new_tokens", self.max_response_length)) + requested_max_tokens = sampling_params.pop("max_tokens", requested_max_tokens) + max_tokens = self._prepare_max_tokens(requested_max_tokens, input_length) + + if "temperature" in kwargs: + sampling_params["temperature"] = kwargs["temperature"] + if "top_p" in kwargs: + sampling_params["top_p"] = kwargs["top_p"] + if "top_k" in kwargs: + sampling_params["top_k"] = kwargs["top_k"] + + tinker_sampling_params = tinker.types.SamplingParams( + max_tokens=max_tokens, + stop=self.stop_sequences, # type: ignore + **sampling_params, + ) + # call sampling client + model_input = _flat_token_input_to_model_input(token_input) + sample_response: tinker.SampleResponse = await self.sampling_client.sample_async( + prompt=model_input, + num_samples=1, + sampling_params=tinker_sampling_params, + ) + + # return sampled sequence from sample response + return sample_response.sequences[0] + + @override + def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: + """ + Assemble model output from a sampled sequence. + """ + sampled_sequence = cast(TinkerTokenOutput, token_output) + response_tokens, logprobs = sampled_sequence.tokens, sampled_sequence.logprobs + + # Parse response using parser (handles content, reasoning, tool_calls) + parsed_output = self.chat_parser.parse_completion(response_tokens) + content = parsed_output.get("content", "") + reasoning = parsed_output.get("reasoning", "") + tool_calls = parsed_output.get("tool_calls", []) + + # decode full text + completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) # type: ignore + finish_reason = sampled_sequence.stop_reason + # special handling for prompt ids, we will break any EncodedTextChunk into ints + prompt_ids = [] + for elem in token_input: + if isinstance(elem, tinker.EncodedTextChunk): + prompt_ids.extend(elem.tokens) + else: + prompt_ids.append(elem) + + return ModelOutput( + text=completion_text, + content=content, + reasoning=reasoning, + tool_calls=tool_calls, + prompt_ids=prompt_ids, + completion_ids=response_tokens, + logprobs=logprobs, + routing_matrices=None, + prompt_length=_flat_token_input_length(token_input), + completion_length=len(response_tokens), + finish_reason=finish_reason, + metrics=None, + ) + + @override + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + """ + Generate model response for a given set of messages. + + Args: + messages: List of message dictionaries (OpenAI format) + **kwargs: Additional parameters including: + - application_id: Session/application ID for tracing + - enforce_max_prompt_length: Whether to enforce max prompt length + - tools: List of tools for tool-augmented generation + + Returns: + ModelOutput with generated text and metadata + """ + # Extract unused kwargs + kwargs.pop("application_id", None) + + # Extract tools + tools = kwargs.pop("tools", []) + + # Build prompt using TinkerChatTemplateParser (handles tools, images, etc.) + tinker_prompt = self.chat_parser.build_prompt(messages, tools=tools) + token_input: TinkerTokenInput = tinker_prompt.chunks + + sampled_sequence = await self.get_token_output_from_token_input(token_input=token_input, **kwargs) + return self.assemble_model_output(token_input=token_input, token_output=sampled_sequence) + + async def compute_logprobs(self, ids: list[int]) -> list[float]: + ids = ids[: self.max_model_length] + return await self.sampling_client.compute_logprobs_async(ModelInput.from_ints(ids)) diff --git a/rllm/engine/rollout/types.py b/rllm/experimental/rollout/types.py similarity index 100% rename from rllm/engine/rollout/types.py rename to rllm/experimental/rollout/types.py diff --git a/rllm/experimental/rollout/verl_engine.py b/rllm/experimental/rollout/verl_engine.py new file mode 100644 index 000000000..4c08073e7 --- /dev/null +++ b/rllm/experimental/rollout/verl_engine.py @@ -0,0 +1,133 @@ +import uuid +from typing import cast + +from omegaconf import DictConfig +from typing_extensions import override +from verl.experimental.agent_loop.agent_loop import AgentLoopManager, AsyncLLMServerManager + +from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine +from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput, VerlTokenOutput +from rllm.parser import ChatTemplateParser +from rllm.workflows import TerminationEvent, TerminationReason + + +class VerlEngine(RolloutEngine): + def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokenizer: Tokenizer, processor=None, **kwargs): + self.config = config + + if config.actor_rollout_ref.rollout.name not in ["vllm", "sglang"]: + raise ValueError(f"VerlEngine only supports vllm or sglang rollout, but got {config.actor_rollout_ref.rollout.name}") + + self.rollout_manager: AgentLoopManager = rollout_manager + # reconstruct the servers list from the server_addresses and server_handles (Verl 0.7.0+) + servers = zip(rollout_manager.server_addresses, rollout_manager.server_handles, strict=True) + self.server_manager = AsyncLLMServerManager(config, servers=servers, load_balancer_handle=rollout_manager.global_load_balancer) + + self.tokenizer = tokenizer + self.processor = processor + self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=config.get("rllm", {}).get("disable_thinking", False)) + + self.max_prompt_length = config.data.max_prompt_length + self.max_response_length = config.data.max_response_length + self.accumulate_reasoning = config.get("rllm", {}).get("accumulate_reasoning", False) + + self.train_sampling_params = dict( + temperature=0.0 if config.actor_rollout_ref.rollout.do_sample is False else config.actor_rollout_ref.rollout.temperature, + top_k=config.actor_rollout_ref.rollout.top_k, + top_p=config.actor_rollout_ref.rollout.top_p, + logprobs=1, + ) + + self.val_sampling_params = dict( + temperature=0.0 if config.actor_rollout_ref.rollout.val_kwargs.do_sample is False else config.actor_rollout_ref.rollout.val_kwargs.temperature, + top_k=config.actor_rollout_ref.rollout.val_kwargs.top_k, + top_p=config.actor_rollout_ref.rollout.val_kwargs.top_p, + logprobs=1, + ) + + print(f"train_sampling_params: {self.train_sampling_params}") + print(f"val_sampling_params: {self.val_sampling_params}") + + @property + def supports_token_in_token_out(self) -> bool: + return True + + @override + async def get_token_output_from_token_input(self, token_input: TokenInput, **kwargs) -> VerlTokenOutput: + token_input = cast(list[int], token_input) + + input_length = len(token_input) + application_id = kwargs.pop("application_id", str(uuid.uuid4())) + enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True) + + if enforce_max_prompt_length and input_length > self.max_prompt_length: + raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED) + + sampling_params = self.val_sampling_params.copy() if self.is_validation else self.train_sampling_params.copy() + sampling_params.update(kwargs) + max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length)) + # starting from verl 0.7.0, we can pass in per-turn max_tokens into the sampling_params + sampling_params["max_tokens"] = max_tokens + + token_output = await self.server_manager.generate(request_id=application_id, prompt_ids=token_input, sampling_params=sampling_params) + return token_output + + @override + async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: + # these go to the parser + tools = kwargs.pop("tools", []) + accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) + reasoning_effort = kwargs.pop("reasoning_effort", "medium") + + prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning, reasoning_effort=reasoning_effort) + request_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) # list[int] + + if any(msg.get("images", None) is not None and msg["role"] == "user" for msg in messages) and self.processor is not None: + image_data = self.chat_parser.process_image_data(messages) # list[PIL.Image.Image] + model_inputs = self.processor(text=[prompt], images=image_data) + prompt_ids = model_inputs.pop("input_ids")[0] # list[int] + model_inputs.pop("attention_mask") + multi_modal_inputs = dict(model_inputs) + else: + image_data = None + multi_modal_inputs = None + prompt_ids = request_prompt_ids + + token_output: TokenOutput = await self.get_token_output_from_token_input(token_input=request_prompt_ids, **kwargs) + extra_kwargs = dict(prompt_ids=prompt_ids, multi_modal_inputs=multi_modal_inputs) + return self.assemble_model_output(token_input=request_prompt_ids, token_output=token_output, **extra_kwargs) + + @override + def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput, **kwargs) -> ModelOutput: + prompt_ids = kwargs.pop("prompt_ids", None) + multi_modal_inputs = kwargs.pop("multi_modal_inputs", None) + prompt_length = len(prompt_ids) if prompt_ids is not None else 0 + + token_output = cast(VerlTokenOutput, token_output) + completion_ids = token_output.token_ids + logprobs = token_output.log_probs + + # convert the stop reason from verl back to the standard finish reason TODO(listar2000): check backward-compatibility + reason_mapping = {"aborted": "abort", "completed": "stop"} + if token_output.stop_reason is not None: + finish_reason = reason_mapping.get(token_output.stop_reason, token_output.stop_reason) + else: + finish_reason = "stop" + + completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True) + # TODO: implement parse_completion for the standard parser + parsed_output = self.chat_parser.parse_completion(completion_ids) + + return ModelOutput( + text=completion_text, + content=parsed_output["content"], + reasoning=parsed_output["reasoning"], + tool_calls=parsed_output["tool_calls"], + prompt_ids=prompt_ids, + completion_ids=completion_ids, + multi_modal_inputs=multi_modal_inputs, + logprobs=logprobs, + prompt_length=prompt_length, + completion_length=len(completion_ids), + finish_reason=finish_reason, + ) diff --git a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py index 3cd314158..ac1d123a4 100644 --- a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py +++ b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py @@ -1,6 +1,6 @@ from rllm.agents.agent import Episode, Trajectory -from rllm.engine.rollout.completer import Completer -from rllm.engine.rollout.rollout_engine import RolloutEngine +from rllm.experimental.rollout.completer import Completer +from rllm.experimental.rollout.rollout_engine import RolloutEngine from rllm.experimental.opsd.workflow_utils import OPSDConfig, opsd_postprocess from rllm.rewards.reward_fn import math_reward_fn from rllm.workflows.workflow import Workflow diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 44013428e..732b02d9d 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -17,7 +17,7 @@ from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset -from rllm.engine.rollout import RolloutEngine +from rllm.experimental.rollout import RolloutEngine from rllm.experimental.common.advantage import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, diff --git a/rllm/experimental/verl/verl_backend.py b/rllm/experimental/verl/verl_backend.py index 86dc21bd3..1929b9305 100644 --- a/rllm/experimental/verl/verl_backend.py +++ b/rllm/experimental/verl/verl_backend.py @@ -33,7 +33,7 @@ from rllm.agents.agent import Episode from rllm.data import Dataset -from rllm.engine.rollout import RolloutEngine, VerlEngine +from rllm.experimental.rollout import RolloutEngine, VerlEngine from rllm.experimental.common import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 5401da041..9810063a5 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -23,7 +23,7 @@ from rllm.agents.agent import Episode from rllm.data import Dataset -from rllm.engine.rollout import RolloutEngine, TinkerEngine +from rllm.experimental.rollout import RolloutEngine, TinkerEngine from rllm.experimental.common import AlgorithmConfig, simple_timer from rllm.experimental.protocol import BackendProtocol from rllm.trainer.tinker.tinker_metrics_utils import ( diff --git a/rllm/trainer/tinker/transform.py b/rllm/trainer/tinker/transform.py index 7a0d73862..b062f14fa 100644 --- a/rllm/trainer/tinker/transform.py +++ b/rllm/trainer/tinker/transform.py @@ -11,8 +11,8 @@ from tinker_cookbook.supervised.common import create_rightshifted_model_input_and_leftshifted_targets from rllm.agents.agent import Trajectory, TrajectoryGroup -from rllm.engine.rollout.tinker_engine import _flat_token_input_length, _flat_token_input_to_model_input -from rllm.engine.rollout.types import TinkerTokenInput +from rllm.experimental.rollout.tinker_engine import _flat_token_input_length, _flat_token_input_to_model_input +from rllm.experimental.rollout.types import TinkerTokenInput from rllm.experimental.common import AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups From c86083b694abe16cc4b8f72a89d3682f00eac932 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sat, 4 Apr 2026 12:42:56 -0700 Subject: [PATCH 15/21] revert TinkerChatTemplateParser and parser changes for separate PR Revert parser files to main (tinker_parser.py, conftest, tests, __init__, chat_template_parser, utils). Revert tinker_engine to main's ChatTemplateParser approach, keeping only super().__init__() and _get_model_response rename. Also restore pyproject.toml to main. Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 11 +- rllm/experimental/rollout/tinker_engine.py | 218 +++++++++-- rllm/parser/__init__.py | 35 +- rllm/parser/chat_template_parser.py | 2 +- rllm/parser/tinker_parser.py | 400 --------------------- rllm/parser/utils.py | 11 - tests/parser/conftest.py | 54 --- tests/parser/test_chat_parser.py | 2 +- tests/parser/test_tinker_parser.py | 224 ------------ 9 files changed, 206 insertions(+), 751 deletions(-) delete mode 100644 rllm/parser/tinker_parser.py delete mode 100644 tests/parser/conftest.py delete mode 100644 tests/parser/test_tinker_parser.py diff --git a/pyproject.toml b/pyproject.toml index c1a8622a6..590d5405c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,15 +84,16 @@ dev = [ ] verl = [ - "verl==0.7.1", - "vllm>=0.10.2,<=0.17.0", + "torch", "transformers>=4.55.0,<5.0.0", "numpy", - "torch", - "torchvision", + "verl==0.7.1", + "ray", + "torch>=2.8.0", + "torchvision>=0.23.0", + "vllm>=0.10.2,<=0.12.0", "flash-attn>=2.8.1", "qwen-vl-utils", - "ray", ] sdk = [ diff --git a/rllm/experimental/rollout/tinker_engine.py b/rllm/experimental/rollout/tinker_engine.py index f428230ca..0665777d2 100644 --- a/rllm/experimental/rollout/tinker_engine.py +++ b/rllm/experimental/rollout/tinker_engine.py @@ -1,13 +1,16 @@ -from typing import cast +import json +from typing import Any, cast import tinker from tinker.types import ModelInput from tinker_cookbook import model_info, renderers +from tinker_cookbook.renderers import Message from typing_extensions import override # need to use typing_extensions for python < 3.12 from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine -from rllm.experimental.rollout.types import ImageProcessor, TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput -from rllm.parser.tinker_parser import TinkerChatTemplateParser +from rllm.experimental.rollout.types import ImageProcessor, Processor, TinkerTokenInput, TinkerTokenOutput, TokenInput, Tokenizer, TokenOutput +from rllm.parser import ChatTemplateParser +from rllm.tools.tool_base import ToolCall from rllm.workflows import TerminationEvent, TerminationReason """ @@ -51,26 +54,115 @@ def _flat_token_input_length(token_input: TokenInput) -> int: return length +def _convert_openai_messages(messages: list[dict[str, Any]]) -> list[Message]: + """Convert OpenAI message dicts to tinker-cookbook Messages. + + Follows the same pattern as tinker_cookbook.third_party.litellm.provider._convert_openai_messages. + TODO: once these helpers are refactored out of the litellm provider into a shared module + (e.g. tinker_cookbook.renderers.openai_compat), import directly instead of duplicating. + """ + from tinker_cookbook.renderers.base import ToolCall as TinkerToolCall + + out: list[Message] = [] + for msg in messages: + tinker_msg: Message = { + "role": msg["role"], + "content": msg.get("content") or "", + } + if "name" in msg: + tinker_msg["name"] = msg["name"] + if "tool_call_id" in msg: + tinker_msg["tool_call_id"] = msg["tool_call_id"] + if "tool_calls" in msg: + tinker_msg["tool_calls"] = [TinkerToolCall.model_validate(tc) for tc in msg["tool_calls"]] + out.append(tinker_msg) + return out + + +def _prepare_messages_with_tools( + renderer: renderers.Renderer, + messages: list[Message], + tools: list[dict[str, Any]], +) -> list[Message]: + """Inject tool declarations into the message list via the renderer. + + Follows the same pattern as tinker_cookbook.third_party.litellm.provider._prepare_messages_with_tools. + TODO: once these helpers are refactored out of the litellm provider into a shared module + (e.g. tinker_cookbook.renderers.openai_compat), import directly instead of duplicating. + """ + from tinker_cookbook.renderers.base import ToolSpec + + tool_specs: list[ToolSpec] = [] + for tool in tools: + if tool.get("type") != "function": + continue + func = tool["function"] + tool_specs.append(ToolSpec(name=func["name"], description=func.get("description", ""), parameters=func.get("parameters", {}))) + + system_prompt = "" + if messages and messages[0]["role"] == "system": + content = messages[0].get("content") or "" + system_prompt = content if isinstance(content, str) else "" + remaining = list(messages[1:]) + else: + remaining = list(messages) + + prefix = renderer.create_conversation_prefix_with_tools(tool_specs, system_prompt) + return prefix + remaining + + +def _parse_tinker_message(message: Message) -> tuple[str, str, list[Any]]: + tinker_content = message["content"] + if isinstance(tinker_content, list): + text_parts, think_parts = [], [] + for part in tinker_content: + if part["type"] == "text": + text_parts.append(part) + elif part["type"] == "thinking": + think_parts.append(part) + content = "\n".join([text["text"] for text in text_parts]) + reasoning = "\n".join([think["thinking"] for think in think_parts]) + else: # no reasoning parsed + content = tinker_content + reasoning = "" + # Convert tinker-cookbook ToolCall (function.name/function.arguments) to rllm ToolCall (name/arguments) + raw_tool_calls = message.get("tool_calls", []) + tool_calls = [] + for tc in raw_tool_calls: + if hasattr(tc, "function"): + # tinker-cookbook ToolCall: ToolCall(function=FunctionBody(name, arguments), id) + args = tc.function.arguments + tool_calls.append(ToolCall(name=tc.function.name, arguments=json.loads(args) if isinstance(args, str) else args)) + elif isinstance(tc, ToolCall): + tool_calls.append(tc) + elif isinstance(tc, dict): + tool_calls.append(ToolCall(name=tc.get("name", ""), arguments=tc.get("arguments", {}))) + else: + raise TypeError(f"Unrecognized tool_call type: {type(tc)}") + return content, reasoning, tool_calls + + class TinkerEngine(RolloutEngine): """ RolloutEngine implementation using Tinker for model inference. - - Wraps the tinker renderer with a TinkerChatTemplateParser, which provides - unified prompt building (including tool spec injection) and response parsing - (content, reasoning, tool_calls). """ def __init__( self, + base_url: str, model_name: str, tokenizer: Tokenizer, service_client: tinker.ServiceClient, - base_url: str | None = None, max_prompt_length: int = 4096, max_response_length: int = 4096, max_model_length: int = 32768, sampling_params: dict | None = None, + bypass_render_with_parser: bool = True, # default to True now + processor: Processor | None = None, image_processor: ImageProcessor | None = None, + disable_thinking: bool = False, + accumulate_reasoning: bool = False, + reasoning_effort: str = "medium", renderer_name: str | None = None, **kwargs, ): @@ -78,18 +170,21 @@ def __init__( Initialize TinkerEngine. Args: + base_url: Tinker service base URL model_name: Name of the model to use tokenizer: Tokenizer for encoding/decoding service_client: Tinker ServiceClient instance - base_url: Tinker service URL (default = null for local) max_prompt_length: Maximum prompt length in tokens max_response_length: Maximum response length in tokens max_model_length: Maximum total length (prompt + response) in tokens sampling_params: Default sampling parameters (temperature, top_p, etc.) + bypass_render_with_parser: If True, use ChatTemplateParser instead of Tinker's renderer + processor: Optional processor for multimodal models (used when bypass_render_with_parser=True) image_processor: Optional image processor for vision-language models (used with renderer) - renderer_name: Optional renderer name to use (None = auto-detect from model) - kwargs: Additional keyword arguments - - strip_thinking_from_history: Whether to strip thinking from history (only for Qwen3Renderer) + disable_thinking: Whether to disable thinking in generation prompt (used when bypass_render_with_parser=True) + accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) + reasoning_effort: The effort level for reasoning (used when bypass_render_with_parser=True) + renderer_name: The name of the renderer to use (used when bypass_render_with_parser=True) """ super().__init__() self.base_url = base_url @@ -98,26 +193,36 @@ def __init__( self.max_response_length = max_response_length self.max_model_length = max_model_length - 1 self.tokenizer = tokenizer + self.bypass_render_with_parser = bypass_render_with_parser + self.accumulate_reasoning = accumulate_reasoning + self.reasoning_effort = reasoning_effort self.train_sampling_params = dict(sampling_params.get("train", {})) if sampling_params else {} self.val_sampling_params = dict(sampling_params.get("val", {})) if sampling_params else {} # Initialize Tinker service client self.service_client = service_client - # Initialize the renderer and wrap with TinkerChatTemplateParser + # Initialize the renderer renderer_name = renderer_name or model_info.get_recommended_renderer_name(self.model_name) - renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) - - if "strip_thinking_from_history" in kwargs and isinstance(kwargs["strip_thinking_from_history"], bool) and hasattr(renderer, "strip_thinking_from_history"): - renderer.strip_thinking_from_history = kwargs["strip_thinking_from_history"] - - self.chat_parser: TinkerChatTemplateParser = TinkerChatTemplateParser(renderer) - self.stop_sequences = self.chat_parser.stop_sequences + # Pass image_processor for VLM support with Tinker renderer + self.renderer = renderers.get_renderer(renderer_name, self.tokenizer, image_processor=image_processor) + + if bypass_render_with_parser: + self.chat_parser = ChatTemplateParser.get_parser(tokenizer, processor=processor, disable_thinking=disable_thinking) + if hasattr(self.chat_parser, "stop_sequences") and self.chat_parser.stop_sequences: + self.stop_sequences = self.chat_parser.stop_sequences + elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id: + self.stop_sequences = [tokenizer.eos_token_id] + else: + raise ValueError("No stop sequences found for tokenizer or chat parser") + else: + self.chat_parser = None + self.stop_sequences = self.renderer.get_stop_sequences() # Sampling client will be set via set_sampling_client() - self.sampling_client: tinker.SamplingClient | None = None + self.sampling_client = None - def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None: + def set_sampling_client(self, sampling_client): """ Set the sampling client for inference. @@ -126,6 +231,25 @@ def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None: """ self.sampling_client = sampling_client + @staticmethod + def _convert_images_to_content_list(messages: list[dict]) -> list[dict]: + """Convert rllm image format to renderer content list format. + + {"content": "text", "images": [PIL.Image]} -> {"content": [ImagePart, TextPart]} + """ + converted = [] + for msg in messages: + if "images" in msg and msg["images"]: + content_list = [] + for img in msg["images"]: + content_list.append({"type": "image", "image": img}) + content_list.append({"type": "text", "text": msg.get("content", "")}) + converted.append({**msg, "content": content_list}) + del converted[-1]["images"] + else: + converted.append(msg) + return converted + def _prepare_max_tokens(self, requested_max_tokens: int, prompt_length: int) -> int: """ Prepare max_tokens parameter, adjusting for max_model_length if needed. @@ -206,11 +330,16 @@ def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutp sampled_sequence = cast(TinkerTokenOutput, token_output) response_tokens, logprobs = sampled_sequence.tokens, sampled_sequence.logprobs - # Parse response using parser (handles content, reasoning, tool_calls) - parsed_output = self.chat_parser.parse_completion(response_tokens) - content = parsed_output.get("content", "") - reasoning = parsed_output.get("reasoning", "") - tool_calls = parsed_output.get("tool_calls", []) + if self.bypass_render_with_parser: + assert self.chat_parser is not None, "chat_parser must be set when bypass_render_with_parser=True" + parsed_output = self.chat_parser.parse_completion(response_tokens) + content = parsed_output.get("content", "") + reasoning = parsed_output.get("reasoning", "") + tool_calls = parsed_output.get("tool_calls", []) + else: + assert isinstance(self.renderer, renderers.Renderer), "self.renderer must be a valid Tinker Renderer" + response_message, _ = self.renderer.parse_response(response_tokens) + content, reasoning, tool_calls = _parse_tinker_message(response_message) # decode full text completion_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True) # type: ignore @@ -231,11 +360,9 @@ def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutp prompt_ids=prompt_ids, completion_ids=response_tokens, logprobs=logprobs, - routing_matrices=None, prompt_length=_flat_token_input_length(token_input), completion_length=len(response_tokens), finish_reason=finish_reason, - metrics=None, ) @override @@ -248,7 +375,8 @@ async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutp **kwargs: Additional parameters including: - application_id: Session/application ID for tracing - enforce_max_prompt_length: Whether to enforce max prompt length - - tools: List of tools for tool-augmented generation + - tools: List of tools (used when bypass_render_with_parser=True) + - accumulate_reasoning: Whether to accumulate reasoning (used when bypass_render_with_parser=True) Returns: ModelOutput with generated text and metadata @@ -256,12 +384,32 @@ async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutp # Extract unused kwargs kwargs.pop("application_id", None) - # Extract tools + # Extract parser-specific kwargs tools = kwargs.pop("tools", []) - - # Build prompt using TinkerChatTemplateParser (handles tools, images, etc.) - tinker_prompt = self.chat_parser.build_prompt(messages, tools=tools) - token_input: TinkerTokenInput = tinker_prompt.chunks + accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning) + reasoning_effort = kwargs.pop("reasoning_effort", self.reasoning_effort) + + if self.bypass_render_with_parser: + # Use ChatTemplateParser + prompt = self.chat_parser.parse( # type: ignore + messages, + add_generation_prompt=True, + is_first_msg=True, + tools=tools, + reasoning_effort=reasoning_effort, + accumulate_reasoning=accumulate_reasoning, + ) + token_input = self.tokenizer.encode(prompt, add_special_tokens=False) # type: ignore + else: + # Use Tinker renderer + # Convert images, then convert OpenAI messages to renderer format + converted_messages = self._convert_images_to_content_list(messages) + tinker_messages = _convert_openai_messages(converted_messages) + # Inject tool definitions via renderer if tools are provided + if tools: + tinker_messages = _prepare_messages_with_tools(self.renderer, tinker_messages, tools) + # Build prompt using renderer + token_input: TinkerTokenInput = self.renderer.build_generation_prompt(tinker_messages).chunks # type: ignore sampled_sequence = await self.get_token_output_from_token_input(token_input=token_input, **kwargs) return self.assemble_model_output(token_input=token_input, token_output=sampled_sequence) diff --git a/rllm/parser/__init__.py b/rllm/parser/__init__.py index 8b5e0b993..277acc2eb 100644 --- a/rllm/parser/__init__.py +++ b/rllm/parser/__init__.py @@ -5,13 +5,27 @@ "DeepseekQwenChatTemplateParser", "QwenChatTemplateParser", "LlamaChatTemplateParser", - "TinkerChatTemplateParser", "ToolParser", "R1ToolParser", "QwenToolParser", ] +def __getattr__(name): + _chat_template_classes = { + "ChatTemplateParser", + "DeepseekQwenChatTemplateParser", + "LlamaChatTemplateParser", + "QwenChatTemplateParser", + } + if name in _chat_template_classes: + import importlib + + mod = importlib.import_module("rllm.parser.chat_template_parser") + return getattr(mod, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + PARSER_REGISTRY = { "r1": R1ToolParser, "qwen": QwenToolParser, @@ -21,22 +35,3 @@ def get_tool_parser(parser_name: str) -> type[ToolParser]: assert parser_name in PARSER_REGISTRY, f"Tool parser {parser_name} not found in {PARSER_REGISTRY}" return PARSER_REGISTRY[parser_name] - - -_CHAT_TEMPLATE_CLASSES = { - "ChatTemplateParser", - "DeepseekQwenChatTemplateParser", - "LlamaChatTemplateParser", - "QwenChatTemplateParser", -} - - -def __getattr__(name): - if name in _CHAT_TEMPLATE_CLASSES: - import importlib - mod = importlib.import_module("rllm.parser.chat_template_parser") - return getattr(mod, name) - if name == "TinkerChatTemplateParser": - from rllm.parser.tinker_parser import TinkerChatTemplateParser - return TinkerChatTemplateParser - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/rllm/parser/chat_template_parser.py b/rllm/parser/chat_template_parser.py index f7b053e39..29a312603 100644 --- a/rllm/parser/chat_template_parser.py +++ b/rllm/parser/chat_template_parser.py @@ -703,7 +703,7 @@ def parse_prompt_from_messages(self, messages, add_generation_prompt=False, is_f raise NotImplementedError(f"Unsupported message role: {message['role']}") conv = Conversation.from_messages(harmony_messages) - accumulate_thinking = kwargs.get("accumulate_reasoning", kwargs.get("accumulate_thinking", False)) + accumulate_thinking = kwargs.get("accumulate_thinking", False) config = RenderConversationConfig(auto_drop_analysis=not accumulate_thinking) prompt_ids: list[int] = self.enc.render_conversation(conv, config) diff --git a/rllm/parser/tinker_parser.py b/rllm/parser/tinker_parser.py deleted file mode 100644 index 0bfe5a9dd..000000000 --- a/rllm/parser/tinker_parser.py +++ /dev/null @@ -1,400 +0,0 @@ -import json -import logging - -import torch - -from rllm.parser.chat_template_parser import ChatTemplateParser -from rllm.tools.tool_base import Tool, ToolCall - -logger = logging.getLogger(__name__) - - -try: - import tinker - from tinker.types import ModelInput - from tinker_cookbook.renderers.base import RenderContext, Renderer, TrainOnWhat -except ImportError as e: - raise ImportError("tinker-cookbook and tinker are required for TinkerChatTemplateParser. Install them with: pip install tinker-cookbook tinker") from e - - -def _make_render_context(idx, is_last, prev_message=None, last_user_index=-1): - """Create a RenderContext, handling version differences in tinker-cookbook.""" - try: - return RenderContext( - idx=idx, - is_last=is_last, - prev_message=prev_message, - last_user_index=last_user_index, - ) - except TypeError: - # Older tinker-cookbook without last_user_index field - return RenderContext(idx=idx, is_last=is_last, prev_message=prev_message) - - -class TinkerChatTemplateParser(ChatTemplateParser): - """ChatTemplateParser that delegates to a tinker-cookbook Renderer. - - This allows users who have tinker-cookbook installed to use any tinker - renderer through rllm's ChatTemplateParser interface, avoiding the need - to write a manual parser for each model family. - - Example:: - - from tinker_cookbook import renderers, tokenizer_utils - from rllm.parser import TinkerChatTemplateParser - - tokenizer = tokenizer_utils.get_tokenizer("Qwen/Qwen3-8B") - renderer = renderers.get_renderer("qwen3", tokenizer) - parser = TinkerChatTemplateParser(renderer) - - prompt = parser.parse(messages, add_generation_prompt=True, is_first_msg=True) - """ - - def __init__(self, renderer: Renderer) -> None: - if not isinstance(renderer, Renderer): - raise TypeError(f"Expected a tinker_cookbook Renderer, got {type(renderer)}") - self.renderer = renderer - self.tokenizer = renderer.tokenizer - self.processor = None - - # Compute generation_prompt by decoding the generation suffix tokens - ctx = _make_render_context(idx=0, is_last=True) - suffix_tokens = self.renderer._get_generation_suffix("assistant", ctx) - self.generation_prompt = self.tokenizer.decode(suffix_tokens) if suffix_tokens else "" - - self.stop_sequences = self.renderer.get_stop_sequences() - - def _convert_message(self, msg: dict) -> dict: - """Convert an rllm message dict to a tinker Message dict.""" - tinker_msg = {"role": msg["role"]} - - content = msg.get("content", "") or "" - reasoning = (msg.get("reasoning", "") or "").strip() - - # Build structured content when reasoning or images are present - if reasoning: - parts = [] - parts.append({"type": "thinking", "thinking": reasoning}) - if content: - parts.append({"type": "text", "text": content}) - tinker_msg["content"] = parts - elif isinstance(msg.get("images"), list) and msg["images"]: - parts = [] - for img in msg["images"]: - parts.append({"type": "image", "image": img}) - if content: - # Strip leading tag if present (rllm convention) - if content.startswith(""): - content = content[len("") :] - parts.append({"type": "text", "text": content}) - tinker_msg["content"] = parts - else: - tinker_msg["content"] = content - - # Convert tool_calls to tinker ToolCall format - if msg.get("tool_calls"): - from tinker_cookbook.renderers.base import ToolCall as TinkerToolCall - - tool_calls = [] - for tc in msg["tool_calls"]: - if isinstance(tc, ToolCall): - # rllm ToolCall dataclass - args = tc.arguments if isinstance(tc.arguments, str) else json.dumps(tc.arguments) - tool_calls.append( - TinkerToolCall( - function=TinkerToolCall.FunctionBody(name=tc.name, arguments=args), - ) - ) - elif isinstance(tc, dict) and "function" in tc: - func = tc["function"] - args = func.get("arguments", "{}") - if not isinstance(args, str): - args = json.dumps(args) - tool_calls.append( - TinkerToolCall( - function=TinkerToolCall.FunctionBody(name=func["name"], arguments=args), - id=tc.get("id"), - ) - ) - elif isinstance(tc, dict) and "name" in tc: - args = tc.get("arguments", "{}") - if not isinstance(args, str): - args = json.dumps(args) - tool_calls.append( - TinkerToolCall( - function=TinkerToolCall.FunctionBody(name=tc["name"], arguments=args), - id=tc.get("id"), - ) - ) - if tool_calls: - tinker_msg["tool_calls"] = tool_calls - - # Handle tool response fields - if msg["role"] == "tool": - if "tool_call_id" in msg: - tinker_msg["tool_call_id"] = msg["tool_call_id"] - if "name" in msg: - tinker_msg["name"] = msg["name"] - - return tinker_msg - - def _convert_messages(self, messages: list[dict]) -> list[dict]: - """Convert a list of rllm message dicts to tinker Message format.""" - return [self._convert_message(m) for m in messages] - - def _convert_tools(self, tools: list[Tool | dict]) -> list[dict]: - """Convert rllm tools to tinker ToolSpec format.""" - tool_specs = [] - for tool in tools: - if isinstance(tool, Tool): - # rllm Tool object - extract from json property - tool_json = tool.json - if "function" in tool_json: - func = tool_json["function"] - tool_specs.append( - { - "name": func["name"], - "description": func.get("description", ""), - "parameters": func.get("parameters", {}), - } - ) - elif isinstance(tool, dict): - if "function" in tool: - func = tool["function"] - tool_specs.append( - { - "name": func["name"], - "description": func.get("description", ""), - "parameters": func.get("parameters", {}), - } - ) - elif "name" in tool: - tool_specs.append( - { - "name": tool["name"], - "description": tool.get("description", ""), - "parameters": tool.get("parameters", {}), - } - ) - return tool_specs - - def _render_to_tokens(self, tinker_messages: list[dict], add_bos: bool = False, add_generation_prompt: bool = False) -> list[int]: - """Render tinker messages to a flat list of token IDs.""" - - chunks = [] - - if add_bos and self.renderer._bos_tokens: - chunks.append(tinker.EncodedTextChunk(tokens=self.renderer._bos_tokens)) - - last_user_idx = max( - (i for i, m in enumerate(tinker_messages) if m["role"] == "user"), - default=-1, - ) - - for idx, msg in enumerate(tinker_messages): - ctx = _make_render_context( - idx=idx, - is_last=(idx == len(tinker_messages) - 1) and not add_generation_prompt, - prev_message=tinker_messages[idx - 1] if idx > 0 else None, - last_user_index=last_user_idx, - ) - rendered = self.renderer.render_message(msg, ctx) - if rendered.header: - chunks.append(rendered.header) - chunks.extend(x for x in rendered.output if not isinstance(x, tinker.EncodedTextChunk) or x.tokens) - - if add_generation_prompt: - suffix_ctx = _make_render_context( - idx=len(tinker_messages), - is_last=True, - prev_message=tinker_messages[-1] if tinker_messages else None, - last_user_index=last_user_idx, - ) - suffix_tokens = self.renderer._get_generation_suffix("assistant", suffix_ctx) - if suffix_tokens: - chunks.append(tinker.EncodedTextChunk(tokens=suffix_tokens)) - - # Flatten chunks to token list - tokens = [] - for chunk in chunks: - if isinstance(chunk, tinker.EncodedTextChunk): - tokens.extend(chunk.tokens) - else: - # ImageChunk or other non-token chunk - use length as placeholder - # This path is for VL models; decode will produce placeholder tokens - tokens.extend([0] * chunk.length) - - return tokens - - def _prepare_messages(self, messages: list[dict], tools: list[Tool | dict] | None = None) -> list[dict]: - """Convert rllm messages to tinker format and prepend tool context if needed. - - Args: - messages: List of rllm message dicts. - tools: Optional list of tools to include in the system prompt. - - Returns: - List of tinker-format message dicts ready for rendering. - """ - tinker_messages = self._convert_messages(messages) - - if tools: - tool_specs = self._convert_tools(tools) - if tool_specs: - try: - system_prompt = "" - if tinker_messages and tinker_messages[0]["role"] == "system": - content = tinker_messages[0]["content"] - if isinstance(content, str): - system_prompt = content - tinker_messages = tinker_messages[1:] - prefix = self.renderer.create_conversation_prefix_with_tools(tool_specs, system_prompt) - tinker_messages = prefix + tinker_messages - except NotImplementedError: - logger.warning(f"Renderer {type(self.renderer).__name__} does not support tool calling. Tools will be ignored.") - - return tinker_messages - - def build_prompt(self, messages: list[dict], tools: list[Tool | dict] | None = None) -> ModelInput: - """Build a ModelInput prompt from messages, preserving image chunks for VLM. - - Unlike parse() which decodes to a string, this returns a ModelInput directly - via the renderer's build_generation_prompt, avoiding the token->string->token - round-trip and preserving ImageChunks for vision-language models. - - Args: - messages: List of rllm message dicts. - tools: Optional list of tools to include in the prompt. - - Returns: - tinker ModelInput with generation prompt appended. - """ - tinker_messages = self._prepare_messages(messages, tools=tools) - return self.renderer.build_generation_prompt(tinker_messages) - - def parse(self, messages: list[dict], add_generation_prompt: bool = False, is_first_msg: bool = False, tools: list[Tool | dict] | None = None, **kwargs) -> str: - """Parse messages into a prompt string. - - Note: For TinkerEngine, prefer build_prompt() which returns a ModelInput - directly and preserves image chunks. This method is for compatibility with - non-Tinker rollout engines. - - Args: - messages: List of rllm message dicts. - add_generation_prompt: Whether to append the generation prompt. - is_first_msg: Whether this is the first message (adds BOS token). - tools: Optional list of tools to include in the prompt. - - Returns: - The rendered prompt string. - """ - if not messages: - return "" - - tinker_messages = self._prepare_messages(messages, tools=tools) - - tokens = self._render_to_tokens(tinker_messages, add_bos=is_first_msg, add_generation_prompt=add_generation_prompt) - result = self.tokenizer.decode(tokens, skip_special_tokens=False) - - # Tinker puts the \n separator in the next message's header, so the last - # message lacks a trailing \n. HF templates always include it. Add it to - # match HF's apply_chat_template output. - if result and not result.endswith("\n"): - result += "\n" - - return result - - def parse_completion(self, completion_ids: list[int]) -> dict[str, str | list]: - """Parse completion token IDs into structured output. - - Args: - completion_ids: List of token IDs from model generation. - - Returns: - Dict with 'content', 'reasoning', and 'tool_calls' keys. - """ - parsed_msg, _success = self.renderer.parse_response(completion_ids) - - content = "" - reasoning = "" - tool_calls = [] - - msg_content = parsed_msg.get("content", "") - if isinstance(msg_content, str): - content = msg_content - elif isinstance(msg_content, list): - text_parts = [] - thinking_parts = [] - for part in msg_content: - if part["type"] == "text": - text_parts.append(part["text"]) - elif part["type"] == "thinking": - thinking_parts.append(part["thinking"]) - content = "".join(text_parts) - reasoning = "".join(thinking_parts) - - # Convert tinker ToolCall objects to rllm ToolCall dataclass - if parsed_msg.get("tool_calls"): - for tc in parsed_msg["tool_calls"]: - try: - args = json.loads(tc.function.arguments) - except (json.JSONDecodeError, TypeError): - args = tc.function.arguments - tool_calls.append(ToolCall(name=tc.function.name, arguments=args)) - - return { - "content": content.strip(), - "reasoning": reasoning.strip(), - "tool_calls": tool_calls, - } - - def tokenize_and_mask(self, messages): - """Convert messages to token IDs with loss masks using tinker's supervised example builder. - - Returns: - Tuple of (prompt_ids, response_ids, response_mask) as torch tensors. - """ - tinker_messages = self._convert_messages(messages) - model_input, weights = self.renderer.build_supervised_example(tinker_messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_MESSAGE) - - all_tokens = model_input.to_ints() - weights_list = weights.tolist() - - # Split at first non-zero weight - boundary = next((i for i, w in enumerate(weights_list) if w > 0), len(weights_list)) - - prompt_ids = torch.tensor(all_tokens[:boundary], dtype=torch.long) - response_ids = torch.tensor(all_tokens[boundary:], dtype=torch.long) - response_mask = weights[boundary:].long() - - return prompt_ids, response_ids, response_mask - - def tokenize_and_mask_cumulative(self, messages): - """Convert multi-turn messages to token IDs with cumulative loss masks. - - Returns: - Tuple of (prompt_ids, response_ids, response_mask) as torch tensors. - """ - tinker_messages = self._convert_messages(messages) - model_input, weights = self.renderer.build_supervised_example(tinker_messages, train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES) - - all_tokens = model_input.to_ints() - weights_list = weights.tolist() - - # Split at first non-zero weight - boundary = next((i for i, w in enumerate(weights_list) if w > 0), len(weights_list)) - - prompt_ids = torch.tensor(all_tokens[:boundary], dtype=torch.long) - response_ids = torch.tensor(all_tokens[boundary:], dtype=torch.long) - response_mask = weights[boundary:].long() - - return prompt_ids, response_ids, response_mask - - def verify_equivalence(self, messages, verbose=True): - """Tinker renderers handle token-level correctness by design. - - NOTE(listar2000): the `verify_equivalence` test from parent does not make too much sense. - Instead of checking equivalence with HF templates, it check single versus multiple message parsing. - So it makes sense for the tinker parser to not pass this test. We simply return True here. - """ - return True diff --git a/rllm/parser/utils.py b/rllm/parser/utils.py index 61f52d40e..e255b04ba 100644 --- a/rllm/parser/utils.py +++ b/rllm/parser/utils.py @@ -6,14 +6,3 @@ {"role": "user", "content": "What about Java?"}, {"role": "assistant", "content": "Let me search for Java information.", "tool_calls": [{"function": {"name": "search", "arguments": '{"query": "Java programming"}'}}]}, ] - -# Simple multi-turn messages for verify_equivalence tests. -# Ends with a user message (representing the prompt before model generation) -# to avoid HF template quirks like Qwen3's tag insertion on the last -# assistant message after the last user query. -SIMPLE_TEST_MESSAGES = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you! How can I help you today?"}, - {"role": "user", "content": "What is the capital of France?"}, -] diff --git a/tests/parser/conftest.py b/tests/parser/conftest.py deleted file mode 100644 index 8ab875f00..000000000 --- a/tests/parser/conftest.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Parser tests require real packages (transformers, pydantic, torch, etc.). - -The root conftest.py stubs out heavy optional dependencies for lightweight unit -tests. This conftest removes the specific stubs so parser integration tests can -use real packages. -""" - -import sys -import types - -# These are the exact modules stubbed by root conftest.py _STUB_MODULES list, -# plus the additional stubs it creates for sub-modules and fake classes. -_ROOT_STUB_MODULES = [ - "numpy", - "httpx", - "transformers", - "datasets", - "ray", - "pandas", - "polars", - "sympy", - "pylatexenc", - "antlr4", - "antlr4_python3_runtime", - "mcp", - "eval_protocol", - "hydra", - "fastapi", - "uvicorn", - "tqdm", - "yaml", - "pydantic", - "wrapt", - "asgiref", - "wandb", - "codetiming", - "click", - # Also stubbed explicitly by root conftest - "torch", - "PIL", - "openai", -] - -# Remove stub modules and any sub-modules created by root conftest -_to_remove = [] -for name in list(sys.modules): - base = name.split(".")[0] - if base in _ROOT_STUB_MODULES: - mod = sys.modules[name] - if isinstance(mod, types.ModuleType) and not hasattr(mod, "__file__"): - _to_remove.append(name) - -for name in _to_remove: - del sys.modules[name] diff --git a/tests/parser/test_chat_parser.py b/tests/parser/test_chat_parser.py index 4bac5428f..d45c7fdd8 100644 --- a/tests/parser/test_chat_parser.py +++ b/tests/parser/test_chat_parser.py @@ -73,7 +73,7 @@ def test_parser_with_disable_thinking(): parser = QwenChatTemplateParser(tokenizer, disable_thinking=True) # Verify that thinking is disabled in the generation prompt - assert "\n\n\n\n" in parser.assistant_token + assert "\\n\\n\\n\\n" in parser.assistant_token # Test equivalence check assert parser.verify_equivalence(PARSER_TEST_MESSAGES) diff --git a/tests/parser/test_tinker_parser.py b/tests/parser/test_tinker_parser.py deleted file mode 100644 index 54dbedea0..000000000 --- a/tests/parser/test_tinker_parser.py +++ /dev/null @@ -1,224 +0,0 @@ -import sys -from unittest.mock import patch - -import pytest -from tinker_cookbook import renderers -from transformers import AutoTokenizer - -from rllm.parser import QwenChatTemplateParser -from rllm.parser.tinker_parser import TinkerChatTemplateParser -from rllm.parser.utils import SIMPLE_TEST_MESSAGES - - -@pytest.fixture -def qwen_tokenizer(): - return AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") - - -@pytest.fixture -def qwen_renderer(qwen_tokenizer): - return renderers.get_renderer("qwen3", qwen_tokenizer) - - -@pytest.fixture -def qwen_tinker_parser(qwen_renderer): - return TinkerChatTemplateParser(qwen_renderer) - - -def test_tinker_parser_init(qwen_tinker_parser): - """Verify that constructor sets up generation_prompt and stop_sequences.""" - assert qwen_tinker_parser.generation_prompt - assert isinstance(qwen_tinker_parser.generation_prompt, str) - assert qwen_tinker_parser.stop_sequences is not None - assert qwen_tinker_parser.tokenizer is not None - assert qwen_tinker_parser.renderer is not None - - -def test_tinker_parser_init_bad_renderer(): - """Verify TypeError when passing a non-renderer object.""" - with pytest.raises(TypeError, match="Expected a tinker_cookbook Renderer"): - TinkerChatTemplateParser("not a renderer") - - -def test_tinker_parser_parse(qwen_tinker_parser): - """Verify parse() returns a valid non-empty string.""" - result = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, add_generation_prompt=True, is_first_msg=True) - assert isinstance(result, str) - assert len(result) > 0 - - -def test_tinker_parser_parse_empty(): - """Verify parse([]) returns empty string.""" - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") - renderer = renderers.get_renderer("qwen3", tokenizer) - parser = TinkerChatTemplateParser(renderer) - assert parser.parse([]) == "" - - -def test_tinker_parser_parse_generation_prompt(qwen_tinker_parser): - """Verify that generation prompt is appended when requested.""" - with_prompt = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, add_generation_prompt=True, is_first_msg=True) - without_prompt = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, add_generation_prompt=False, is_first_msg=True) - # The version with generation prompt should be longer - assert len(with_prompt) > len(without_prompt) - - -def test_tinker_parser_parse_is_first_msg(qwen_tinker_parser): - """Verify is_first_msg controls BOS token inclusion.""" - with_bos = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, is_first_msg=True) - without_bos = qwen_tinker_parser.parse(SIMPLE_TEST_MESSAGES, is_first_msg=False) - # With BOS should be at least as long as without - assert len(with_bos) >= len(without_bos) - - -def test_tinker_parser_parse_with_reasoning(qwen_tinker_parser): - """Verify that reasoning is included when accumulate_reasoning=True.""" - messages = [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there", "reasoning": "The user greeted me"}, - ] - with_reasoning = qwen_tinker_parser.parse(messages, accumulate_reasoning=True, is_first_msg=True) - without_reasoning = qwen_tinker_parser.parse(messages, accumulate_reasoning=False, is_first_msg=True) - assert "think" in with_reasoning or len(with_reasoning) > len(without_reasoning) - - -def test_tinker_parser_parse_completion(qwen_tinker_parser, qwen_tokenizer): - """Verify parse_completion returns correct structure.""" - # Encode a proper assistant response with thinking + end token. - # The renderer expects tokens as if produced by the model during generation, - # which means they must end with the stop sequence (<|im_end|> for Qwen3). - text = "\nLet me think about this.\n\n\nHello, how can I help?<|im_end|>" - token_ids = qwen_tokenizer.encode(text, add_special_tokens=False) - - result = qwen_tinker_parser.parse_completion(token_ids) - - assert isinstance(result, dict) - assert "content" in result - assert "reasoning" in result - assert "tool_calls" in result - assert isinstance(result["tool_calls"], list) - # The thinking should be extracted as reasoning - assert result["reasoning"] - assert "think" in result["reasoning"].lower() - assert "Hello" in result["content"] - - -def test_tinker_parser_tokenize_and_mask(qwen_tinker_parser): - """Verify tokenize_and_mask returns correct tensor shapes and mask values.""" - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "4"}, - ] - prompt_ids, response_ids, response_mask = qwen_tinker_parser.tokenize_and_mask(messages) - - assert prompt_ids.dim() == 1 - assert response_ids.dim() == 1 - assert response_mask.dim() == 1 - assert len(response_ids) == len(response_mask) - assert len(prompt_ids) > 0 - assert len(response_ids) > 0 - # Response mask should have non-zero values - assert response_mask.sum() > 0 - - -def test_tinker_parser_tokenize_and_mask_cumulative(qwen_tinker_parser): - """Verify tokenize_and_mask_cumulative returns correct tensor shapes.""" - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "4"}, - {"role": "user", "content": "And 3+3?"}, - {"role": "assistant", "content": "6"}, - ] - prompt_ids, response_ids, response_mask = qwen_tinker_parser.tokenize_and_mask_cumulative(messages) - - assert prompt_ids.dim() == 1 - assert response_ids.dim() == 1 - assert response_mask.dim() == 1 - assert len(response_ids) == len(response_mask) - assert len(prompt_ids) > 0 - assert len(response_ids) > 0 - # Both assistant responses should be masked - assert response_mask.sum() > 0 - # Should have some zero-masked tokens (user message between assistants) - assert (response_mask == 0).any() - - -def test_tinker_parser_verify_equivalence(qwen_tinker_parser): - """Tinker parser should always return True for verify_equivalence.""" - assert qwen_tinker_parser.verify_equivalence(SIMPLE_TEST_MESSAGES) is True - - -def test_tinker_parser_matches_manual_qwen(qwen_tokenizer): - """Compare TinkerChatTemplateParser output with QwenChatTemplateParser for simple messages.""" - renderer = renderers.get_renderer("qwen3", qwen_tokenizer) - tinker_parser = TinkerChatTemplateParser(renderer) - manual_parser = QwenChatTemplateParser(qwen_tokenizer) - - # Simple messages without tool calls (avoid tool call format differences) - simple_messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, - ] - - tinker_result = tinker_parser.parse(simple_messages, add_generation_prompt=False, is_first_msg=True) - manual_result = manual_parser.parse(simple_messages, add_generation_prompt=False, is_first_msg=True) - - # Tokenize both and compare token sequences (more robust than string comparison - # because decode round-trip may differ in whitespace/special token rendering). - # Strip trailing whitespace since HF templates add \n after <|im_end|> but - # tinker's token-level rendering does not. - tinker_tokens = qwen_tokenizer.encode(tinker_result.rstrip(), add_special_tokens=False) - manual_tokens = qwen_tokenizer.encode(manual_result.rstrip(), add_special_tokens=False) - assert tinker_tokens == manual_tokens - - -def test_tinker_parser_message_conversion(qwen_tinker_parser): - """Test that message conversion handles various message formats.""" - messages = [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hi"}, - { - "role": "assistant", - "content": "Let me search.", - "tool_calls": [{"function": {"name": "search", "arguments": '{"q": "test"}'}}], - }, - ] - converted = qwen_tinker_parser._convert_messages(messages) - assert len(converted) == 3 - assert converted[0]["role"] == "system" - assert converted[1]["role"] == "user" - assert converted[2]["role"] == "assistant" - - -def test_import_error_without_tinker(): - """Verify helpful ImportError when tinker-cookbook is not installed.""" - # The module-level import in tinker_parser.py raises ImportError if tinker-cookbook - # is not installed. Since the module is already imported, we verify the error message - # by checking the module-level try/except pattern exists. - import importlib - - saved_modules = {} - modules_to_remove = [key for key in sys.modules if key.startswith(("tinker_cookbook", "tinker"))] - # Also remove the cached tinker_parser module so it can be re-imported - if "rllm.parser.tinker_parser" in sys.modules: - saved_modules["rllm.parser.tinker_parser"] = sys.modules.pop("rllm.parser.tinker_parser") - for key in modules_to_remove: - saved_modules[key] = sys.modules.pop(key) - - try: - with patch.dict( - sys.modules, - { - "tinker_cookbook": None, - "tinker_cookbook.renderers": None, - "tinker_cookbook.renderers.base": None, - "tinker": None, - }, - ): - with pytest.raises(ImportError, match="tinker-cookbook and tinker are required"): - importlib.import_module("rllm.parser.tinker_parser") - finally: - sys.modules.update(saved_modules) From a5b8b4f4cca456dc80472c0e74eea15e6d49a0c9 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sat, 4 Apr 2026 12:51:52 -0700 Subject: [PATCH 16/21] revert bypass_render_with_parser and tinker parser-related changes Revert config, docs, examples, and rollout files that referenced bypass_render_with_parser (now staying in tinker_engine since we reverted to main's ChatTemplateParser approach). Clean up tinker_backend to only retain async-related changes. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/experimental/rllm-and-backend-config.md | 1 + examples/countdown/train_countdown_distill_tinker.sh | 1 + examples/math_distill/opsd/train_deepmath_distill_tinker.sh | 1 + examples/math_distill/train_deepmath_distill_tinker.py | 1 + examples/math_distill/train_deepmath_distill_tinker.sh | 1 + rllm/experimental/config/rllm/backend/tinker.yaml | 5 ++++- rllm/experimental/rollout/completer.py | 6 ++++-- rllm/experimental/rollout/types.py | 3 +-- rllm/experimental/test_examples/opsd/math_opsd_workflow.py | 2 +- rllm/trainer/config/tinker_rl_trainer.yaml | 1 + rllm/trainer/tinker/tinker_backend.py | 4 +--- 11 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/experimental/rllm-and-backend-config.md b/docs/experimental/rllm-and-backend-config.md index 2f9fc8de5..4baebd341 100644 --- a/docs/experimental/rllm-and-backend-config.md +++ b/docs/experimental/rllm-and-backend-config.md @@ -238,6 +238,7 @@ This file contains: | `rollout_engine.reasoning_effort` | `str` | `medium` | Reasoning effort mode | | `rollout_engine.accumulate_reasoning` | `bool` | `false` | Whether to accumulate reasoning across steps | | `rollout_engine.disable_thinking` | `bool` | `false` | Whether to disable thinking tokens | +| `rollout_engine.bypass_render_with_parser` | `bool` | `false` | Whether to bypass render parsing | | `rollout_engine.renderer_name` | `str | null` | `null` | Optional renderer name | | `data.max_prompt_length` | `int` | `2048` | Max prompt length | | `data.max_response_length` | `int` | `2048` | Max response length | diff --git a/examples/countdown/train_countdown_distill_tinker.sh b/examples/countdown/train_countdown_distill_tinker.sh index 1107a312d..7b3a17d5f 100644 --- a/examples/countdown/train_countdown_distill_tinker.sh +++ b/examples/countdown/train_countdown_distill_tinker.sh @@ -24,3 +24,4 @@ python -m examples.countdown.train_countdown_tinker \ trainer.test_freq=10 \ trainer.save_freq=1000 \ trainer.default_local_dir='./outputs/countdown-distill-tinker-8b' \ + rollout_engine.bypass_render_with_parser=True diff --git a/examples/math_distill/opsd/train_deepmath_distill_tinker.sh b/examples/math_distill/opsd/train_deepmath_distill_tinker.sh index 43c2c74ed..cf3a8492d 100644 --- a/examples/math_distill/opsd/train_deepmath_distill_tinker.sh +++ b/examples/math_distill/opsd/train_deepmath_distill_tinker.sh @@ -25,4 +25,5 @@ python -m examples.math_distill.opsd.train_deepmath_distill_tinker \ training.default_local_dir='./outputs/opsd-deepmath-8b-rllm' \ rllm.algorithm.use_precomputed_advantage=true \ rllm.algorithm.loss_fn=importance_sampling \ + rollout_engine.bypass_render_with_parser=True \ rllm.workflow.n_parallel_tasks=512 diff --git a/examples/math_distill/train_deepmath_distill_tinker.py b/examples/math_distill/train_deepmath_distill_tinker.py index fb2628721..d4dc5f343 100644 --- a/examples/math_distill/train_deepmath_distill_tinker.py +++ b/examples/math_distill/train_deepmath_distill_tinker.py @@ -26,6 +26,7 @@ def main(config: DictConfig): tokenizer=teacher_tokenizer, service_client=teacher_service_client, sampling_client=teacher_sampling_client, + bypass_render_with_parser=True, ) trainer = AgentTrainer( diff --git a/examples/math_distill/train_deepmath_distill_tinker.sh b/examples/math_distill/train_deepmath_distill_tinker.sh index 69a769592..26efe10dc 100644 --- a/examples/math_distill/train_deepmath_distill_tinker.sh +++ b/examples/math_distill/train_deepmath_distill_tinker.sh @@ -25,5 +25,6 @@ python -m examples.math_distill.train_deepmath_distill_tinker \ training.default_local_dir='./outputs/deepmath-distill-8b-32b-unified' \ rllm.algorithm.use_precomputed_advantage=true \ rllm.algorithm.loss_fn=importance_sampling \ + rollout_engine.bypass_render_with_parser=False \ rollout_engine.renderer_name=qwen3 \ rllm.workflow.n_parallel_tasks=512 diff --git a/rllm/experimental/config/rllm/backend/tinker.yaml b/rllm/experimental/config/rllm/backend/tinker.yaml index 549df9d21..225106d7e 100644 --- a/rllm/experimental/config/rllm/backend/tinker.yaml +++ b/rllm/experimental/config/rllm/backend/tinker.yaml @@ -59,7 +59,10 @@ agent: # Tinker Engine Configuration rollout_engine: - strip_thinking_from_history: true + reasoning_effort: "medium" + accumulate_reasoning: false + disable_thinking: false + bypass_render_with_parser: false renderer_name: null # Data Configuration diff --git a/rllm/experimental/rollout/completer.py b/rllm/experimental/rollout/completer.py index 8aa034124..8c818db90 100644 --- a/rllm/experimental/rollout/completer.py +++ b/rllm/experimental/rollout/completer.py @@ -16,7 +16,9 @@ from rllm.agents.agent import Step from rllm.experimental.rollout.rollout_engine import ModelOutput, RolloutEngine from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput -from rllm.parser import ChatTemplateParser + +if TYPE_CHECKING: + from rllm.parser import ChatTemplateParser class Completer: @@ -86,7 +88,7 @@ def __init__(self, rollout_engine: RolloutEngine): raise ValueError(f"The rollout engine {cls_name} does not support token-in-token-out") # we also require the rollout engine has a chat parser and a tokenizer if rollout_engine.chat_parser is None or rollout_engine.tokenizer is None: - raise ValueError("The rollout engine must have a chat parser and a tokenizer.") + raise ValueError("The rollout engine must have a chat parser and a tokenizer. For Tinker engine, make sure you have set bypass_render_with_parser=True.") self.tokenizer = rollout_engine.tokenizer self.chat_parser = rollout_engine.chat_parser diff --git a/rllm/experimental/rollout/types.py b/rllm/experimental/rollout/types.py index d52466d2d..22b30195b 100644 --- a/rllm/experimental/rollout/types.py +++ b/rllm/experimental/rollout/types.py @@ -17,8 +17,7 @@ Processor: TypeAlias = Any ImageProcessor: TypeAlias = Any -# Tinker types. -# See https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py +# Tinker types. See https://github.com/thinking-machines-lab/tinker-cookbook/blob/main/tinker_cookbook/rl/data_processing.py # for the rationale behind "FlatObElem" and "FlatOb" types. try: from tinker.types import ModelInputChunk, SampledSequence diff --git a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py index ac1d123a4..8f9488591 100644 --- a/rllm/experimental/test_examples/opsd/math_opsd_workflow.py +++ b/rllm/experimental/test_examples/opsd/math_opsd_workflow.py @@ -1,7 +1,7 @@ from rllm.agents.agent import Episode, Trajectory +from rllm.experimental.opsd.workflow_utils import OPSDConfig, opsd_postprocess from rllm.experimental.rollout.completer import Completer from rllm.experimental.rollout.rollout_engine import RolloutEngine -from rllm.experimental.opsd.workflow_utils import OPSDConfig, opsd_postprocess from rllm.rewards.reward_fn import math_reward_fn from rllm.workflows.workflow import Workflow diff --git a/rllm/trainer/config/tinker_rl_trainer.yaml b/rllm/trainer/config/tinker_rl_trainer.yaml index 8862068a6..95630a37c 100644 --- a/rllm/trainer/config/tinker_rl_trainer.yaml +++ b/rllm/trainer/config/tinker_rl_trainer.yaml @@ -69,6 +69,7 @@ rollout_engine: reasoning_effort: "medium" accumulate_reasoning: false disable_thinking: false + bypass_render_with_parser: false renderer_name: null # Override renderer name (null = auto-detect from model) # Data Configuration diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index 9810063a5..e0689ff7d 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -23,9 +23,9 @@ from rllm.agents.agent import Episode from rllm.data import Dataset -from rllm.experimental.rollout import RolloutEngine, TinkerEngine from rllm.experimental.common import AlgorithmConfig, simple_timer from rllm.experimental.protocol import BackendProtocol +from rllm.experimental.rollout import RolloutEngine, TinkerEngine from rllm.trainer.tinker.tinker_metrics_utils import ( print_metrics_table, update_training_metrics, @@ -117,8 +117,6 @@ def init_rollout_engine(self, **kwargs) -> RolloutEngine: Args: **kwargs: Additional arguments, including the various configurations - - strip_thinking_from_history: Whether to strip thinking from history (default = true) - - renderer_name: Name of the renderer to use (default = auto-detect from model) Returns: TinkerEngine: The initialized rollout engine. From 4b67829e0e5ba533c65d619a183696691209b425 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sat, 4 Apr 2026 13:47:46 -0700 Subject: [PATCH 17/21] remove engine/gateway-level gate mechanism The per-request gate on RolloutEngine is unnecessary: - partial_rollout=True: verl handles abort/resume at server level, Tinker hot-swaps weights in place - partial_rollout=False: coordination happens at task dispatch level (coordinator pause/resume), not per-request Remove close_gate/open_gate/wait_for_gate/wait_for_drain from RolloutEngine, GatewayManager, and model-gateway proxy/server/client. Remove needs_weight_sync_gate from BackendProtocol. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/rllm_model_gateway/client.py | 26 --------- .../src/rllm_model_gateway/proxy.py | 43 +-------------- .../src/rllm_model_gateway/server.py | 15 ----- rllm/experimental/engine/gateway_manager.py | 14 ----- rllm/experimental/protocol.py | 1 - rllm/experimental/rollout/rollout_engine.py | 55 +------------------ rllm/experimental/unified_trainer.py | 30 +--------- rllm/trainer/tinker/tinker_backend.py | 1 - 8 files changed, 7 insertions(+), 178 deletions(-) diff --git a/rllm-model-gateway/src/rllm_model_gateway/client.py b/rllm-model-gateway/src/rllm_model_gateway/client.py index e1d1a20bd..8dffab091 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/client.py +++ b/rllm-model-gateway/src/rllm_model_gateway/client.py @@ -140,19 +140,6 @@ def health(self) -> dict[str, Any]: resp.raise_for_status() return resp.json() - # -- Gate (weight sync) ------------------------------------------------ - - def close_gate(self) -> None: - resp = self._http.post(f"{self.gateway_url}/admin/gate/close") - resp.raise_for_status() - - def open_gate(self) -> None: - resp = self._http.post(f"{self.gateway_url}/admin/gate/open") - resp.raise_for_status() - - def wait_for_drain(self, timeout: float | None = None) -> None: - resp = self._http.post(f"{self.gateway_url}/admin/gate/drain", timeout=timeout) - resp.raise_for_status() class AsyncGatewayClient: @@ -281,16 +268,3 @@ async def health(self) -> dict[str, Any]: resp.raise_for_status() return resp.json() - # -- Gate (weight sync) ------------------------------------------------ - - async def close_gate(self) -> None: - resp = await self._http.post(f"{self.gateway_url}/admin/gate/close") - resp.raise_for_status() - - async def open_gate(self) -> None: - resp = await self._http.post(f"{self.gateway_url}/admin/gate/open") - resp.raise_for_status() - - async def wait_for_drain(self, timeout: float | None = None) -> None: - resp = await self._http.post(f"{self.gateway_url}/admin/gate/drain", timeout=timeout) - resp.raise_for_status() diff --git a/rllm-model-gateway/src/rllm_model_gateway/proxy.py b/rllm-model-gateway/src/rllm_model_gateway/proxy.py index 1e8a9e38b..00429746a 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/proxy.py +++ b/rllm-model-gateway/src/rllm_model_gateway/proxy.py @@ -89,13 +89,6 @@ def __init__( self._http: httpx.AsyncClient | None = None self._pending_traces: set[asyncio.Task[None]] = set() - # Gate for weight sync: when closed, new requests wait; in-flight requests finish. - self._gate = asyncio.Event() - self._gate.set() # open by default - self._active_requests: int = 0 - self._drained = asyncio.Event() - self._drained.set() - async def start(self) -> None: self._http = httpx.AsyncClient( timeout=httpx.Timeout(timeout=None), # no timeout — LLM calls can be long @@ -113,35 +106,6 @@ async def stop(self) -> None: await self._http.aclose() self._http = None - # ------------------------------------------------------------------ - # Gate (weight sync) - # ------------------------------------------------------------------ - - def close_gate(self) -> None: - """Block new requests from proceeding. In-flight requests continue.""" - self._gate.clear() - - def open_gate(self) -> None: - """Allow new requests to proceed.""" - self._gate.set() - - async def wait_for_drain(self) -> None: - """Wait until all in-flight requests have completed.""" - if self._active_requests == 0: - return - self._drained.clear() - await self._drained.wait() - - def _on_request_start(self) -> None: - self._active_requests += 1 - self._drained.clear() - - def _on_request_end(self) -> None: - self._active_requests -= 1 - if self._active_requests <= 0: - self._active_requests = 0 - self._drained.set() - # ------------------------------------------------------------------ # Main entrypoint # ------------------------------------------------------------------ @@ -152,12 +116,7 @@ async def _ensure_started(self) -> None: async def handle(self, request: Request) -> Response: """Proxy *request* to an inference worker, capture trace, return response.""" - await self._gate.wait() - self._on_request_start() - try: - return await self._handle_inner(request) - finally: - self._on_request_end() + return await self._handle_inner(request) async def _handle_inner(self, request: Request) -> Response: await self._ensure_started() diff --git a/rllm-model-gateway/src/rllm_model_gateway/server.py b/rllm-model-gateway/src/rllm_model_gateway/server.py index 55ba85f6f..74b3cb18b 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/server.py +++ b/rllm-model-gateway/src/rllm_model_gateway/server.py @@ -270,21 +270,6 @@ async def reload(): # Placeholder for hot-reload return {"status": "ok"} - @app.post("/admin/gate/close") - async def gate_close(): - proxy.close_gate() - return {"status": "closed"} - - @app.post("/admin/gate/open") - async def gate_open(): - proxy.open_gate() - return {"status": "open"} - - @app.post("/admin/gate/drain") - async def gate_drain(): - await proxy.wait_for_drain() - return {"status": "drained"} - # -- Proxy catch-all (must be last) ------------------------------------ @app.api_route( diff --git a/rllm/experimental/engine/gateway_manager.py b/rllm/experimental/engine/gateway_manager.py index 68524acca..abcc24191 100644 --- a/rllm/experimental/engine/gateway_manager.py +++ b/rllm/experimental/engine/gateway_manager.py @@ -167,20 +167,6 @@ def stop(self) -> None: self._local_handler = None - # -- Gate (weight sync) --------------------------------------------------- - - def close_gate(self) -> None: - """Stop forwarding new inference requests through the gateway.""" - self.client.close_gate() - - async def wait_for_drain(self) -> None: - """Wait for all in-flight inference requests to complete.""" - await self.async_client.wait_for_drain() - - def open_gate(self) -> None: - """Resume forwarding inference requests through the gateway.""" - self.client.open_gate() - # -- Session / trace API ------------------------------------------------- def create_session(self, session_id: str, is_validation: bool = False) -> str: diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index d5ce100ac..71bfa6e16 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -37,7 +37,6 @@ class BackendProtocol(ABC, Generic[TDataset, TBatch]): name: str = "base_backend" requires_loop: bool = False - needs_weight_sync_gate: bool = True def __init__(self, config: DictConfig, **kwargs): """Initialize the backend. diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py index eaa2d38a0..4771d83f3 100644 --- a/rllm/experimental/rollout/rollout_engine.py +++ b/rllm/experimental/rollout/rollout_engine.py @@ -1,4 +1,3 @@ -import asyncio import logging from dataclasses import dataclass from typing import TYPE_CHECKING @@ -75,64 +74,16 @@ class RolloutEngine: is_validation: bool = False # flag enabled/disabled by AgentWorkflowEngine.execute_tasks def __init__(self, *args, **kwargs): - # Gate mechanism for pausing model calls during weight sync - self._gate: asyncio.Event = asyncio.Event() - self._gate.set() # open by default - self._active_calls: int = 0 - self._drained_event: asyncio.Event = asyncio.Event() - self._drained_event.set() # initially drained (no active calls) self.weight_version: int = 0 - # --- Gate mechanism --- - - def close_gate(self) -> None: - """Close the gate. New model calls will block at wait_for_gate().""" - logger.info(f"[RolloutEngine] Closing gate. Active calls: {self._active_calls}") - self._gate.clear() - - def open_gate(self) -> None: - """Open the gate, releasing any blocked model calls.""" - logger.info(f"[RolloutEngine] Opening gate. Active calls: {self._active_calls}") - self._gate.set() - - def on_model_call_complete(self) -> None: - """Unregister active call. Engines will call this at the END of get_model_response().""" - self._active_calls -= 1 - if self._active_calls <= 0: - self._active_calls = 0 - self._drained_event.set() - logger.debug("[RolloutEngine] All active calls drained.") - else: - logger.debug(f"[RolloutEngine] Model call complete. Active calls: {self._active_calls}") - - async def wait_for_gate(self) -> None: - """Wait until gate is open, then register as active call. Engines will call this at the START of get_model_response().""" - if not self._gate.is_set(): - logger.info(f"[RolloutEngine] Waiting for gate to open. Active calls: {self._active_calls}") - await self._gate.wait() - self._active_calls += 1 - self._drained_event.clear() - logger.debug(f"[RolloutEngine] Gate passed. Active calls: {self._active_calls}") - - async def wait_for_drain(self) -> None: - """Wait until all active model calls complete. Used during weight sync.""" - if not self._drained_event.is_set(): - logger.info(f"[RolloutEngine] Waiting for drain. Active calls: {self._active_calls}") - await self._drained_event.wait() - # --- Model response --- async def _get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: raise NotImplementedError(f"_get_model_response is not implemented for {self.__class__.__name__}") async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput: - await self.wait_for_gate() - try: - weight_version = self.weight_version - result = await self._get_model_response(messages, **kwargs) - result.weight_version = weight_version - return result - finally: - self.on_model_call_complete() + result = await self._get_model_response(messages, **kwargs) + result.weight_version = self.weight_version + return result def assemble_model_output(self, token_input: TokenInput, token_output: TokenOutput) -> ModelOutput: """ diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 732b02d9d..9a4c10b3c 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -708,26 +708,8 @@ async def _training_loop( break async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: SyncCoordinator, rollout_engine: RolloutEngine | None) -> None: - """Synchronize weights between training and rollout engines. - - Gating behavior depends on backend.needs_weight_sync_gate: - - False (e.g. Tinker): skip gating, just update weights in-place. - - True + partial_rollout=True: gate at model-call level (rollout engine or gateway). - Workflows block between turns, resume with new weights. - - True + partial_rollout=False: pause at dispatch level (coordinator). - Workflows finish naturally, gate stays open. - """ - gateway = getattr(self.agent_workflow_engine, "gateway", None) - - if self.async_config.partial_rollout: - if self.backend.needs_weight_sync_gate: - if rollout_engine is not None: - rollout_engine.close_gate() - await rollout_engine.wait_for_drain() - elif gateway is not None: - gateway.close_gate() - await gateway.wait_for_drain() - else: + """Synchronize weights between training and rollout engines.""" + if not self.async_config.partial_rollout: coordinator.pause_generation() await self._wait_for_drain() @@ -737,13 +719,7 @@ async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: S rollout_engine.weight_version = trainer_state.weight_version coordinator.on_sync_complete() - if self.async_config.partial_rollout: - if self.backend.needs_weight_sync_gate: - if rollout_engine is not None: - rollout_engine.open_gate() - elif gateway is not None: - gateway.open_gate() - else: + if not self.async_config.partial_rollout: coordinator.resume_generation() async def _wait_for_drain(self) -> None: diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index e0689ff7d..d254c7584 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -67,7 +67,6 @@ class TinkerBackend(BackendProtocol[Iterable, list[tinker.Datum]]): name: str = "tinker" requires_loop: bool = True # Tinker uses async operations - needs_weight_sync_gate: bool = False # Tinker swaps sampling_client in-place, no gating needed def __init__( self, From bc7c37fab672bb021274636ee2f8d6f80f7f1a21 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sat, 4 Apr 2026 14:33:28 -0700 Subject: [PATCH 18/21] refactor: move task tracking to coordinator, revert validation rename, cleanup - Move _in_flight_tasks tracking from UnifiedTrainer to SyncCoordinator - Add epoch start/end hooks to async generation loop - Remove dead _EPISODE_STRIP_KEYS constants from buffer - Revert is_validation rename in engine/ (defer to future PR) - Restore rllm-model-gateway/ to main Co-Authored-By: Claude Opus 4.6 (1M context) --- .../src/rllm_model_gateway/client.py | 2 -- .../src/rllm_model_gateway/proxy.py | 3 --- rllm/engine/agent_sdk_engine.py | 4 ++-- rllm/engine/agent_workflow_engine.py | 4 ++-- rllm/experimental/buffer.py | 4 ---- rllm/experimental/protocol.py | 13 ++++--------- rllm/experimental/sync_coordinator.py | 15 +++++++++++++++ rllm/experimental/unified_trainer.py | 19 +++++++------------ 8 files changed, 30 insertions(+), 34 deletions(-) diff --git a/rllm-model-gateway/src/rllm_model_gateway/client.py b/rllm-model-gateway/src/rllm_model_gateway/client.py index 8dffab091..48292fe8d 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/client.py +++ b/rllm-model-gateway/src/rllm_model_gateway/client.py @@ -141,7 +141,6 @@ def health(self) -> dict[str, Any]: return resp.json() - class AsyncGatewayClient: """Async variant of :class:`GatewayClient` using ``httpx.AsyncClient``.""" @@ -267,4 +266,3 @@ async def health(self) -> dict[str, Any]: resp = await self._http.get(f"{self.gateway_url}/health") resp.raise_for_status() return resp.json() - diff --git a/rllm-model-gateway/src/rllm_model_gateway/proxy.py b/rllm-model-gateway/src/rllm_model_gateway/proxy.py index 00429746a..3d9c3430f 100644 --- a/rllm-model-gateway/src/rllm_model_gateway/proxy.py +++ b/rllm-model-gateway/src/rllm_model_gateway/proxy.py @@ -116,9 +116,6 @@ async def _ensure_started(self) -> None: async def handle(self, request: Request) -> Response: """Proxy *request* to an inference worker, capture trace, return response.""" - return await self._handle_inner(request) - - async def _handle_inner(self, request: Request) -> Response: await self._ensure_started() session_id: str | None = request.state.session_id originally_requested_logprobs: bool = getattr(request.state, "originally_requested_logprobs", False) diff --git a/rllm/engine/agent_sdk_engine.py b/rllm/engine/agent_sdk_engine.py index ce3622de9..353d6e1ac 100644 --- a/rllm/engine/agent_sdk_engine.py +++ b/rllm/engine/agent_sdk_engine.py @@ -449,11 +449,11 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": self.rollout_engine.wake_up() if batch.meta_info.get("validate", False): - self.rollout_engine.is_validation = True + self.rollout_engine.validate = True tasks = batch.non_tensor_batch["extra_info"].tolist() task_ids = batch.non_tensor_batch["task_ids"].tolist() episodes = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes - self.rollout_engine.is_validation = False + self.rollout_engine.validate = False if isinstance(self.rollout_engine, VerlEngine): await self.rollout_engine.sleep() diff --git a/rllm/engine/agent_workflow_engine.py b/rllm/engine/agent_workflow_engine.py index 2d1752760..367034f71 100644 --- a/rllm/engine/agent_workflow_engine.py +++ b/rllm/engine/agent_workflow_engine.py @@ -220,14 +220,14 @@ async def execute_tasks_verl(self, batch: DataProto, **kwargs) -> DataProto: is_validation = batch.meta_info.get("validate", False) if is_validation: - self.rollout_engine.is_validation = True + self.rollout_engine.validate = True self.current_mode = "val" else: self.current_mode = "train" tasks = batch.non_tensor_batch["extra_info"].tolist() task_ids = batch.non_tensor_batch["task_ids"].tolist() results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes - self.rollout_engine.is_validation = False + self.rollout_engine.validate = False await self.rollout_engine.sleep() diff --git a/rllm/experimental/buffer.py b/rllm/experimental/buffer.py index 4ac3d0623..8e20df1a1 100644 --- a/rllm/experimental/buffer.py +++ b/rllm/experimental/buffer.py @@ -29,10 +29,6 @@ logger = logging.getLogger(__name__) -_EPISODE_STRIP_KEYS = {"prompt_ids", "response_ids", "logprobs", "model_output", "routing_matrices"} -_EPISODE_STRIP_LIST_DEFAULTS = {"prompt_ids", "response_ids", "logprobs"} - - @dataclass class TaskBatch: """All trajectory groups produced from one task's episodes, plus stripped episodes for UI logging.""" diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index 71bfa6e16..b446cdc2d 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -205,6 +205,10 @@ async def on_epoch_end(self, trainer_state: TrainerState) -> None: """Hook method called at the end of an epoch.""" pass + async def on_policy_updated(self, trainer_state: TrainerState) -> None: + """Hook called immediately after update_policy() for weight sync.""" + pass + async def on_validation_start(self, trainer_state: TrainerState) -> bool: """Hook method called at the start of validation. @@ -214,15 +218,6 @@ async def on_validation_start(self, trainer_state: TrainerState) -> bool: trainer_state.is_training = False return True - async def on_policy_updated(self, trainer_state: TrainerState) -> None: - """Hook called immediately after update_policy(). Backends sync weights here. - - For Tinker-like remote/distributed backends: save checkpoint, create new sampling_client. - For Verl-like colocated backends: trigger NCCL sync to rollout workers. - Default: no-op (sync mode uses on_batch_end for this). - """ - pass - async def on_validation_end(self, trainer_state: TrainerState) -> None: """Hook method called at the end of validation.""" trainer_state.is_training = True diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py index 066cfc5ac..33ee13102 100644 --- a/rllm/experimental/sync_coordinator.py +++ b/rllm/experimental/sync_coordinator.py @@ -44,6 +44,9 @@ def __init__(self, config: SyncCoordinatorConfig): self._generation_paused: asyncio.Event = asyncio.Event() self._generation_paused.set() + # Tracks in-flight async rollout tasks for drain/wait logic + self._in_flight_tasks: set[asyncio.Task] = set() + @property def weight_version(self) -> int: return self._weight_version @@ -102,6 +105,18 @@ def resume_generation(self) -> None: async def wait_for_generation_allowed(self) -> None: await self._generation_paused.wait() + # --- In-flight task tracking --- + + def track_task(self, task: asyncio.Task) -> None: + """Register an in-flight rollout task.""" + self._in_flight_tasks.add(task) + task.add_done_callback(self._in_flight_tasks.discard) + + async def wait_for_drain(self) -> None: + """Wait for all in-flight rollout tasks to complete.""" + while self._in_flight_tasks: + await asyncio.sleep(0.1) + def stats(self) -> dict: return { "async/weight_version": self._weight_version, diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 9a4c10b3c..1abdaa2c3 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -247,8 +247,6 @@ def __init__( if hasattr(self.backend, "tokenizer"): self.tokenizer = self.backend.tokenizer - # Tracks in-flight async rollout tasks for drain/wait logic - self._in_flight_tasks: set[asyncio.Task] = set() def _validate_and_setup_configs(self): """Validate and setup common configs.""" @@ -542,6 +540,7 @@ async def _generation_loop( try: for epoch in range(self.rllm_config.trainer.total_epochs): + await self.backend.on_epoch_start(trainer_state) train_dataloader = self.backend.get_dataloader(self.train_dataset, trainer_state) self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) @@ -561,10 +560,11 @@ async def _run_rollout(t=task, tid=task_id, ridx=rollout_idx): ) await buffer.add_episode(tid, episode) t = asyncio.create_task(_run_rollout()) - self._in_flight_tasks.add(t) - t.add_done_callback(self._in_flight_tasks.discard) + coordinator.track_task(t) - await self._wait_for_drain() + await self.backend.on_epoch_end(trainer_state) + + await coordinator.wait_for_drain() finally: buffer.mark_generation_complete() @@ -711,7 +711,7 @@ async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: S """Synchronize weights between training and rollout engines.""" if not self.async_config.partial_rollout: coordinator.pause_generation() - await self._wait_for_drain() + await coordinator.wait_for_drain() trainer_state.weight_version = coordinator.weight_version + 1 await self.backend.on_policy_updated(trainer_state) @@ -722,15 +722,10 @@ async def _perform_weight_sync(self, trainer_state: TrainerState, coordinator: S if not self.async_config.partial_rollout: coordinator.resume_generation() - async def _wait_for_drain(self) -> None: - """Wait for all in-flight rollout tasks to complete.""" - while self._in_flight_tasks: - await asyncio.sleep(0.1) - async def _validate_async_with_pause(self, trainer_state: TrainerState, coordinator: SyncCoordinator) -> dict: """Validation with dispatch-level pause. Waits for workflows to drain, then runs validation.""" coordinator.pause_generation() - await self._wait_for_drain() + await coordinator.wait_for_drain() try: return await self._validate_async(trainer_state) finally: From 7550fdad0f9ca5dcbc306687b2b5e4e70114b21c Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sat, 4 Apr 2026 14:39:31 -0700 Subject: [PATCH 19/21] restore load_balancer assertion in verl_engine, revert tool_base to main Co-Authored-By: Claude Opus 4.6 (1M context) --- rllm/experimental/rollout/verl_engine.py | 2 ++ rllm/tools/tool_base.py | 31 ------------------------ 2 files changed, 2 insertions(+), 31 deletions(-) diff --git a/rllm/experimental/rollout/verl_engine.py b/rllm/experimental/rollout/verl_engine.py index 4c08073e7..a9b5d99e3 100644 --- a/rllm/experimental/rollout/verl_engine.py +++ b/rllm/experimental/rollout/verl_engine.py @@ -18,6 +18,8 @@ def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokeni if config.actor_rollout_ref.rollout.name not in ["vllm", "sglang"]: raise ValueError(f"VerlEngine only supports vllm or sglang rollout, but got {config.actor_rollout_ref.rollout.name}") + assert rollout_manager.global_load_balancer is not None, "global_load_balancer is not available. Issues with RayPPOTrainer's `init_workers()` function." + self.rollout_manager: AgentLoopManager = rollout_manager # reconstruct the servers list from the server_addresses and server_handles (Verl 0.7.0+) servers = zip(rollout_manager.server_addresses, rollout_manager.server_handles, strict=True) diff --git a/rllm/tools/tool_base.py b/rllm/tools/tool_base.py index 446c599d2..91b8a2ec8 100644 --- a/rllm/tools/tool_base.py +++ b/rllm/tools/tool_base.py @@ -1,5 +1,4 @@ import inspect -import json from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -11,22 +10,10 @@ class ToolCall: name: str arguments: dict[str, Any] - id: str | None = None - metadata: dict | None = None def to_dict(self): return {"name": self.name, "arguments": self.arguments} - def to_openai_format(self): - return { - "id": self.id or "unknown", - "type": "function", - "function": { - "name": self.name, - "arguments": json.dumps(self.arguments), - }, - } - @dataclass class ToolOutput: @@ -58,24 +45,6 @@ def to_string(self) -> str: """ return str(self) - def to_dict(self) -> dict: - """Convert the tool output to a dictionary for JSON serialization.""" - return { - "name": self.name, - "output": self.output, - "error": self.error, - "metadata": self.metadata, - } - - def to_openai_format(self) -> dict: - """Convert the tool output to OpenAI tool message format.""" - tool_call_id = (self.metadata.get("call_id") if self.metadata else None) or "unknown" - return { - "role": "tool", - "content": self.to_string(), - "tool_call_id": tool_call_id, - } - class Tool: """ From 4f05c8e99a2a899d9107c2c7fe35eeed29fda7f8 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sat, 4 Apr 2026 15:38:29 -0700 Subject: [PATCH 20/21] fix: add future annotations to rollout_engine for TYPE_CHECKING imports Co-Authored-By: Claude Opus 4.6 (1M context) --- rllm/experimental/rollout/rollout_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rllm/experimental/rollout/rollout_engine.py b/rllm/experimental/rollout/rollout_engine.py index 4771d83f3..3837d59ec 100644 --- a/rllm/experimental/rollout/rollout_engine.py +++ b/rllm/experimental/rollout/rollout_engine.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from dataclasses import dataclass from typing import TYPE_CHECKING From 799324364be5b433717dc40cdb1da5641600d843 Mon Sep 17 00:00:00 2001 From: Kyle Montgomery Date: Sun, 5 Apr 2026 13:45:35 -0700 Subject: [PATCH 21/21] style: fix ruff lint and format issues on unified-fully-async branch Auto-fixed import sorting, unused imports, and formatting across 13 files. Manual fixes: TYPE_CHECKING import for tqdm in buffer.py, isinstance union syntax in metrics.py, moved logger below imports in unified_trainer.py, split long log line. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../train_countdown_unified_tinker.py | 2 +- rllm/experimental/buffer.py | 28 ++++++++-------- rllm/experimental/common/config.py | 4 +-- .../engine/remote_agent_flow_engine.py | 12 +++++-- .../engine/unified_workflow_engine.py | 5 +-- rllm/experimental/metrics.py | 7 ++-- rllm/experimental/protocol.py | 2 +- rllm/experimental/sync_coordinator.py | 6 ++-- rllm/experimental/unified_trainer.py | 32 +++++++++---------- rllm/experimental/verl/verl_backend.py | 2 +- rllm/trainer/tinker/tinker_backend.py | 1 - rllm/trainer/tinker/tinker_metrics_utils.py | 1 - rllm/trainer/tinker/transform.py | 15 +++------ 13 files changed, 54 insertions(+), 63 deletions(-) diff --git a/examples/countdown/unified_trainer/train_countdown_unified_tinker.py b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py index 9b114ed79..dadc98b20 100644 --- a/examples/countdown/unified_trainer/train_countdown_unified_tinker.py +++ b/examples/countdown/unified_trainer/train_countdown_unified_tinker.py @@ -1,8 +1,8 @@ import hydra from rllm.data.dataset import DatasetRegistry -from rllm.rewards.countdown_reward import countdown_reward_fn from rllm.experimental.unified_trainer import AgentTrainer +from rllm.rewards.countdown_reward import countdown_reward_fn from rllm.workflows.simple_workflow import SimpleWorkflow diff --git a/rllm/experimental/buffer.py b/rllm/experimental/buffer.py index 8e20df1a1..288c72fcc 100644 --- a/rllm/experimental/buffer.py +++ b/rllm/experimental/buffer.py @@ -12,6 +12,10 @@ import pickle import tempfile from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tqdm import tqdm from rllm.agents.agent import Episode, TrajectoryGroup from rllm.experimental.common import ( @@ -32,6 +36,7 @@ @dataclass class TaskBatch: """All trajectory groups produced from one task's episodes, plus stripped episodes for UI logging.""" + groups: list[TrajectoryGroup] episodes: list[Episode] = field(default_factory=list) @@ -66,7 +71,7 @@ def __init__( rs_config: RejectionSamplingConfig, episode_offload_dir: str | None = None, trajectory_group_offload_dir: str | None = None, - pbar: "tqdm | None" = None, + pbar: tqdm | None = None, ): self._group_size = group_size self._coordinator = coordinator @@ -161,7 +166,9 @@ async def add_episode(self, task_id: str, episode: Episode) -> bool: # 2. Transform episodes -> trajectory groups traj_groups, transform_metrics = transform_episodes_to_trajectory_groups( - episodes, self._transform_config, self._cf_config, + episodes, + self._transform_config, + self._cf_config, ) self._aggregator.record_dict(transform_metrics) @@ -176,7 +183,8 @@ async def add_episode(self, task_id: str, episode: Episode) -> bool: # 4. Compute advantages adv_metrics = collect_reward_and_advantage_from_trajectory_groups( - traj_groups, self._algorithm_config, + traj_groups, + self._algorithm_config, ) self._aggregator.record_dict(adv_metrics) @@ -184,15 +192,7 @@ async def add_episode(self, task_id: str, episode: Episode) -> bool: filtered_zero_adv = 0 if self._rs_config.filter_uniform_groups: before_adv = len(traj_groups) - traj_groups = [ - g for g in traj_groups - if any( - abs(step.advantage) > 1e-8 - for traj in g.trajectories - for step in traj.steps - if step.advantage is not None - ) - ] + traj_groups = [g for g in traj_groups if any(abs(step.advantage) > 1e-8 for traj in g.trajectories for step in traj.steps if step.advantage is not None)] filtered_zero_adv = before_adv - len(traj_groups) self._aggregator.record("groups/dropped_zero_adv", filtered_zero_adv) @@ -264,10 +264,10 @@ def _record_episode_metrics(self, episodes: list[Episode]) -> None: @staticmethod def _min_weight_version(episodes: list[Episode]) -> int: - min_v = float('inf') + min_v = float("inf") for ep in episodes: for traj in ep.trajectories: for step in traj.steps: if step.weight_version is not None: min_v = min(min_v, step.weight_version) - return int(min_v) if min_v != float('inf') else 0 + return int(min_v) if min_v != float("inf") else 0 diff --git a/rllm/experimental/common/config.py b/rllm/experimental/common/config.py index 97e414532..dee2f2d13 100644 --- a/rllm/experimental/common/config.py +++ b/rllm/experimental/common/config.py @@ -37,9 +37,7 @@ def __post_init__(self): self.fwd_bwd_group_size = self.mini_batch_size if self.enable: assert self.fwd_bwd_group_size >= 1 - assert self.mini_batch_size % self.fwd_bwd_group_size == 0, ( - f"mini_batch_size ({self.mini_batch_size}) must be divisible by fwd_bwd_group_size ({self.fwd_bwd_group_size})" - ) + assert self.mini_batch_size % self.fwd_bwd_group_size == 0, f"mini_batch_size ({self.mini_batch_size}) must be divisible by fwd_bwd_group_size ({self.fwd_bwd_group_size})" @dataclass diff --git a/rllm/experimental/engine/remote_agent_flow_engine.py b/rllm/experimental/engine/remote_agent_flow_engine.py index 327901d85..0bf67d6f6 100644 --- a/rllm/experimental/engine/remote_agent_flow_engine.py +++ b/rllm/experimental/engine/remote_agent_flow_engine.py @@ -93,7 +93,12 @@ async def execute_tasks( return episodes async def process_task_with_retry( - self, task: dict, task_id: str, rollout_idx: int, result_idx: int, **kwargs, + self, + task: dict, + task_id: str, + rollout_idx: int, + result_idx: int, + **kwargs, ) -> tuple[str, int, int, Episode]: """Process a single task with concurrency control.""" async with self._semaphore: @@ -105,7 +110,10 @@ async def process_task_with_retry( session_url = self.gateway.get_session_url(session_id) submission = TaskSubmission( - task=task, session_id=session_id, task_id=task_id, inference_url=session_url, + task=task, + session_id=session_id, + task_id=task_id, + inference_url=session_url, ) results = await self.runtime.execute_tasks([submission], timeout=self.session_timeout) result = results[0] diff --git a/rllm/experimental/engine/unified_workflow_engine.py b/rllm/experimental/engine/unified_workflow_engine.py index 085b4e8e2..037cbd9a3 100644 --- a/rllm/experimental/engine/unified_workflow_engine.py +++ b/rllm/experimental/engine/unified_workflow_engine.py @@ -12,7 +12,6 @@ from rllm.agents.agent import Episode from rllm.experimental.rollout import RolloutEngine -from rllm.utils import colorful_print from rllm.workflows.store import Store from rllm.workflows.workflow import TerminationReason, Workflow @@ -157,9 +156,7 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i elif len(traj.steps) > 0: reward = f"{traj.steps[-1].reward:.1f}" reward_strs.append(f"{traj.name}: {reward}") - logger.debug( - f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}" - ) + logger.debug(f"[{uid}] Rollout completed. Rewards: [{', '.join(reward_strs)}], Termination: {episode.termination_reason}") if episode.termination_reason != TerminationReason.ERROR: return task_id, rollout_idx, result_idx, episode diff --git a/rllm/experimental/metrics.py b/rllm/experimental/metrics.py index 55cfa8cd9..fe2ce1ef4 100644 --- a/rllm/experimental/metrics.py +++ b/rllm/experimental/metrics.py @@ -10,7 +10,6 @@ import numpy as np - # Keys that should be summed rather than averaged. _SUM_KEYS: set[str] = { "groups/num_trajs_before_filter", @@ -29,9 +28,7 @@ ) # Prefixes where "mean" is the correct reduction. -_MEAN_PREFIXES: tuple[str, ...] = ( - "episode/", -) +_MEAN_PREFIXES: tuple[str, ...] = ("episode/",) def _infer_rule(key: str) -> str: @@ -107,7 +104,7 @@ def record(self, key: str, value: float) -> None: def record_dict(self, metrics: dict) -> None: """Record all numeric values from a dict, coercing types.""" for k, v in metrics.items(): - if isinstance(v, (int, float)): + if isinstance(v, int | float): self._values[k].append(float(v)) elif isinstance(v, np.number): self._values[k].append(float(v)) diff --git a/rllm/experimental/protocol.py b/rllm/experimental/protocol.py index b446cdc2d..ebdb79a0b 100644 --- a/rllm/experimental/protocol.py +++ b/rllm/experimental/protocol.py @@ -16,8 +16,8 @@ from rllm.agents.agent import Episode from rllm.data import Dataset -from rllm.experimental.rollout import RolloutEngine from rllm.experimental.common.advantage import AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups +from rllm.experimental.rollout import RolloutEngine if TYPE_CHECKING: from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine diff --git a/rllm/experimental/sync_coordinator.py b/rllm/experimental/sync_coordinator.py index 33ee13102..1405f3c0a 100644 --- a/rllm/experimental/sync_coordinator.py +++ b/rllm/experimental/sync_coordinator.py @@ -8,8 +8,8 @@ @dataclass class SyncCoordinatorConfig: - mini_batch_size: int # episode groups per optimizer step - group_size: int # episodes per group (rollout.n) + mini_batch_size: int # episode groups per optimizer step + group_size: int # episodes per group (rollout.n) staleness_threshold: float trigger_parameter_sync_step: int @@ -32,7 +32,7 @@ def __init__(self, config: SyncCoordinatorConfig): self._weight_version: int = 0 self._quota_used: int = 0 # groups counting toward current sync window quota (includes carryover) - self._in_flight: int = 0 # groups dispatched but not yet consumed/filtered + self._in_flight: int = 0 # groups dispatched but not yet consumed/filtered self._steps_since_sync: int = 0 self._total_syncs: int = 0 diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 1abdaa2c3..45819f7c7 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -9,15 +9,13 @@ from pprint import pprint from typing import Any, Literal -logger = logging.getLogger(__name__) - import numpy as np from omegaconf import DictConfig, OmegaConf from tqdm import tqdm from rllm.agents.agent import Episode, TrajectoryGroup from rllm.data import Dataset -from rllm.experimental.rollout import RolloutEngine +from rllm.experimental.buffer import TrajectoryGroupBuffer from rllm.experimental.common.advantage import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, @@ -40,14 +38,16 @@ ) from rllm.experimental.common.visualization import print_metrics_table, visualize_trajectory_last_steps from rllm.experimental.engine.unified_workflow_engine import UnifiedWorkflowEngine -from rllm.experimental.buffer import TrajectoryGroupBuffer from rllm.experimental.metrics import MetricsAggregator from rllm.experimental.protocol import BackendProtocol +from rllm.experimental.rollout import RolloutEngine from rllm.experimental.sync_coordinator import SyncCoordinator, SyncCoordinatorConfig from rllm.utils import EpisodeLogger, Tracking, extract_source_metadata from rllm.workflows.store import Store from rllm.workflows.workflow import TerminationReason, Workflow +logger = logging.getLogger(__name__) + @dataclass class TrainerState: @@ -247,7 +247,6 @@ def __init__( if hasattr(self.backend, "tokenizer"): self.tokenizer = self.backend.tokenizer - def _validate_and_setup_configs(self): """Validate and setup common configs.""" # validate common, backend-agnostic configs @@ -482,12 +481,8 @@ async def _train_batch_async(self, batch: Any, trainer_state: TrainerState) -> N async def _fit_fully_async(self, trainer_state: TrainerState) -> None: """Fully-async generation + training with group-level streaming.""" - assert self.config.data.train_batch_size == 1, ( - f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" - ) - assert not getattr(self.agent_workflow_engine, "raise_on_error", False), ( - "Async training requires raise_on_error=False so that process_task_with_retry always returns an episode" - ) + assert self.config.data.train_batch_size == 1, f"Async training requires train_batch_size=1, got {self.config.data.train_batch_size}" + assert not getattr(self.agent_workflow_engine, "raise_on_error", False), "Async training requires raise_on_error=False so that process_task_with_retry always returns an episode" coord_config = SyncCoordinatorConfig( mini_batch_size=self.async_config.mini_batch_size, group_size=self.rllm_config.rollout.n, @@ -533,7 +528,10 @@ async def _fit_fully_async(self, trainer_state: TrainerState) -> None: pbar.close() async def _generation_loop( - self, trainer_state: TrainerState, buffer: TrajectoryGroupBuffer, coordinator: SyncCoordinator, + self, + trainer_state: TrainerState, + buffer: TrajectoryGroupBuffer, + coordinator: SyncCoordinator, ) -> None: """Generate episodes and stream to TrajectoryGroupBuffer.""" group_size = self.rllm_config.rollout.n @@ -554,11 +552,11 @@ async def _generation_loop( task_id = str(uuid.uuid4()) for rollout_idx in range(group_size): + async def _run_rollout(t=task, tid=task_id, ridx=rollout_idx): - _, _, _, episode = await self.agent_workflow_engine.process_task_with_retry( - task=t, task_id=tid, rollout_idx=ridx, result_idx=0 - ) + _, _, _, episode = await self.agent_workflow_engine.process_task_with_retry(task=t, task_id=tid, rollout_idx=ridx, result_idx=0) await buffer.add_episode(tid, episode) + t = asyncio.create_task(_run_rollout()) coordinator.track_task(t) @@ -593,7 +591,9 @@ async def _training_loop( done = False buffered = buffer._queue.qsize() - logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: waiting for {mini_batch_size} task batches ({num_fwd_bwd_passes} fwd-bwd passes x {fwd_bwd_group_size} groups), {buffered} buffered") + logger.info( + f"[TrainingLoop] Step {trainer_state.global_step}: waiting for {mini_batch_size} task batches ({num_fwd_bwd_passes} fwd-bwd passes x {fwd_bwd_group_size} groups), {buffered} buffered" + ) # 1. Pull mini_batch_size task batches total, split into # num_fwd_bwd_passes forward-backward passes of fwd_bwd_group_size each. diff --git a/rllm/experimental/verl/verl_backend.py b/rllm/experimental/verl/verl_backend.py index 1929b9305..06bbd49e7 100644 --- a/rllm/experimental/verl/verl_backend.py +++ b/rllm/experimental/verl/verl_backend.py @@ -33,13 +33,13 @@ from rllm.agents.agent import Episode from rllm.data import Dataset -from rllm.experimental.rollout import RolloutEngine, VerlEngine from rllm.experimental.common import ( AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups, simple_timer, ) from rllm.experimental.protocol import BackendProtocol +from rllm.experimental.rollout import RolloutEngine, VerlEngine from rllm.experimental.verl import compute_advantage_verl, transform_episodes_to_dataproto, update_dataproto_with_advantages if TYPE_CHECKING: diff --git a/rllm/trainer/tinker/tinker_backend.py b/rllm/trainer/tinker/tinker_backend.py index d254c7584..00ffd46ef 100644 --- a/rllm/trainer/tinker/tinker_backend.py +++ b/rllm/trainer/tinker/tinker_backend.py @@ -27,7 +27,6 @@ from rllm.experimental.protocol import BackendProtocol from rllm.experimental.rollout import RolloutEngine, TinkerEngine from rllm.trainer.tinker.tinker_metrics_utils import ( - print_metrics_table, update_training_metrics, ) from rllm.trainer.tinker.tinker_policy_trainer import TinkerPolicyTrainer diff --git a/rllm/trainer/tinker/tinker_metrics_utils.py b/rllm/trainer/tinker/tinker_metrics_utils.py index 805ed9ad4..618a6ce30 100644 --- a/rllm/trainer/tinker/tinker_metrics_utils.py +++ b/rllm/trainer/tinker/tinker_metrics_utils.py @@ -11,7 +11,6 @@ logger = logging.getLogger(__name__) - def compute_kl_and_entropy_metrics(training_datums: list[tinker.Datum], training_logprobs: list[torch.Tensor]) -> dict: """ Compute KL divergence and entropy metrics from training. diff --git a/rllm/trainer/tinker/transform.py b/rllm/trainer/tinker/transform.py index b062f14fa..13497ab63 100644 --- a/rllm/trainer/tinker/transform.py +++ b/rllm/trainer/tinker/transform.py @@ -11,9 +11,9 @@ from tinker_cookbook.supervised.common import create_rightshifted_model_input_and_leftshifted_targets from rllm.agents.agent import Trajectory, TrajectoryGroup +from rllm.experimental.common import AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups from rllm.experimental.rollout.tinker_engine import _flat_token_input_length, _flat_token_input_to_model_input from rllm.experimental.rollout.types import TinkerTokenInput -from rllm.experimental.common import AlgorithmConfig, collect_reward_and_advantage_from_trajectory_groups def _is_prefix(seq1: TinkerTokenInput, seq2: TinkerTokenInput) -> bool: @@ -125,10 +125,7 @@ def make_datum_from_state(): SequenceAccumulator.mask.extend([0.0] * delta_token_input_length + [1.0] * len(output_token_ids)) if router_replay: step_rm = step.routing_matrices or [] - SequenceAccumulator.routing_matrices.extend( - [""] * delta_token_input_length - + (list(step_rm) if step_rm else [""] * len(output_token_ids)) - ) + SequenceAccumulator.routing_matrices.extend([""] * delta_token_input_length + (list(step_rm) if step_rm else [""] * len(output_token_ids))) if SequenceAccumulator.full_sequence: data.append(make_datum_from_state()) @@ -149,12 +146,7 @@ def transform_trajectory_groups_to_datums( Otherwise, we return a list of datums. """ # step 1: compute advantages (skip if already pre-computed by buffer) - has_advantages = any( - step.advantage is not None - for group in trajectory_groups - for traj in group.trajectories - for step in traj.steps - ) + has_advantages = any(step.advantage is not None for group in trajectory_groups for traj in group.trajectories for step in traj.steps) if has_advantages: adv_metrics = {} else: @@ -181,6 +173,7 @@ def transform_trajectory_groups_to_datums( if seqs_per_traj: import numpy as _np + adv_metrics["batch/seqs_per_traj/mean"] = _np.mean(seqs_per_traj) adv_metrics["batch/seqs_per_traj/min"] = _np.min(seqs_per_traj) adv_metrics["batch/seqs_per_traj/max"] = _np.max(seqs_per_traj)