Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rllm/experimental/rollout/verl_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class VerlEngine(RolloutEngine):
def __init__(self, config: DictConfig, rollout_manager: AgentLoopManager, tokenizer: Tokenizer, processor=None, **kwargs):
super().__init__()
self.config = config

if config.actor_rollout_ref.rollout.name not in ["vllm", "sglang"]:
Expand Down
17 changes: 15 additions & 2 deletions rllm/experimental/unified_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,15 @@ async def _generation_loop(
self.agent_workflow_engine.set_training_step(trainer_state.global_step, mode="train", epoch=epoch)

for batch in train_dataloader:
task = batch[0]
# Extract single task from batch (train_batch_size=1).
# verl dataloader yields dicts (with extra_info, raw_prompt keys),
# not indexable lists.
if isinstance(batch, dict):
task = batch.get("extra_info", [{}])[0]
elif hasattr(batch, 'non_tensor_batch'):
task = batch.non_tensor_batch.get("extra_info", [{}])[0]
else:
task = batch[0]

await coordinator.wait_for_generation_allowed()
if not coordinator.has_quota():
Expand Down Expand Up @@ -622,12 +630,14 @@ async def _training_loop(

# Forward-backward on this chunk
trainer_state.trajectory_groups = chunk_groups
trainer_state.episodes = all_episodes

if trainer_state.has_trajectory_groups:
logger.info(f"[TrainingLoop] Step {trainer_state.global_step}: fwd-bwd pass {pass_idx + 1}/{num_fwd_bwd_passes} ({len(chunk_groups)} groups)")
await self.backend.on_batch_start(trainer_state)
trainer_state.backend_batch = self.backend.transform_to_backend_batch(trainer_state)
await self.backend.process_backend_batch(trainer_state)
await self.backend.compute_advantages(trainer_state, self.algorithm_config)

# Drain per-chunk backend metrics into aggregator
aggregator.record_dict(trainer_state.metrics)
Expand Down Expand Up @@ -686,7 +696,10 @@ async def _training_loop(
step_time = trainer_state.metrics.get("time/step", 1.0)
trainer_state.metrics["async/trainer_idle_ratio"] = buffer_wait_time / max(step_time, 1e-9)

# 7. on_batch_end writes backend metrics (progress, optim, timing)
# 7. on_batch_end writes backend metrics (progress, optim, timing).
# Populate timing_dict["step"] so verl's compute_throughout_metrics can compute throughput.
# The sync path uses `with simple_timer("step", timing_dict)` but async doesn't.
trainer_state.timing_dict["step"] = time.perf_counter() - step_start
await self.backend.on_batch_end(trainer_state)

# 7. Print and log
Expand Down
Loading