diff --git a/rllm/experimental/rollout/verl_engine.py b/rllm/experimental/rollout/verl_engine.py index a9b5d99e3..705b05ee6 100644 --- a/rllm/experimental/rollout/verl_engine.py +++ b/rllm/experimental/rollout/verl_engine.py @@ -13,6 +13,7 @@ class VerlEngine(RolloutEngine): def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokenizer: Tokenizer, processor=None, **kwargs): + super().__init__() self.config = config if config.actor_rollout_ref.rollout.name not in ["vllm", "sglang"]: diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index 45819f7c7..5ea0523c7 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -543,7 +543,15 @@ async def _generation_loop( self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch) for batch in train_dataloader: - task = batch[0] + # Extract single task from batch (train_batch_size=1). + # verl dataloader yields dicts (with extra_info, raw_prompt keys), + # not indexable lists.
+ if isinstance(batch, dict): + task = batch.get("extra_info", [{}])[0] + elif hasattr(batch, "non_tensor_batch"): + task = batch.non_tensor_batch.get("extra_info", [{}])[0] + else: + task = batch[0] await coordinator.wait_for_generation_allowed() if not coordinator.has_quota(): @@ -622,12 +630,14 @@ async def _training_loop( # Forward-backward on this chunk trainer_state.trajectory_groups = chunk_groups + trainer_state.episodes = all_episodes if trainer_state.has_trajectory_groups: logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: fwd-bwd pass {pass_idx + 1}/{num_fwd_bwd_passes} ({len(chunk_groups)} groups)") await self.backend.on_batch_start(trainer_state) trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state) await self.backend.process_backend_batch(trainer_state) + await self.backend.compute_advantages(trainer_state, self.algorithm_config) # Drain per-chunk backend metrics into aggregator aggregator.record_dict(trainer_state.metrics) @@ -686,7 +696,10 @@ async def _training_loop( step_time = trainer_state.metrics.get("time/step", 1.0) trainer_state.metrics["async/trainer_idle_ratio"] = buffer_wait_time / max(step_time, 1e-9) - # 7. on_batch_end writes backend metrics (progress, optim, timing) + # 7. on_batch_end writes backend metrics (progress, optim, timing). + # Populate timing_dict["step"] so verl's compute_throughout_metrics can compute throughput. + # The sync path uses `with simple_timer("step", timing_dict)` but async doesn't. + trainer_state.timing_dict["step"] = time.perf_counter() - step_start await self.backend.on_batch_end(trainer_state) # 7. Print and log