From 924bcf5f33dc779ea1db1763791faa6de9634519 Mon Sep 17 00:00:00 2001 From: yifannnwu Date: Thu, 2 Apr 2026 20:33:07 +0000 Subject: [PATCH] fix: support verl 0.7.1 new EngineWorker in agent_workflow_trainer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit verl 0.7.1 defaults to `use_legacy_worker_impl: disable`, which uses `EngineWorker` instead of `TrainingWorker`. This changes the worker API: - `compute_log_prob` / `compute_ref_log_prob` / `update_actor` now return `TensorDict` instead of `DataProto` - Workers operate in no-padding format internally, so outputs must be converted back via `no_padding_2_padding` - `ppo_loss` requires `global_batch_size`, `temperature` etc. in the batch TensorDict; `compute_log_prob` needs `compute_loss=False` and `calculate_entropy=True` Without this fix, training crashes with: KeyError: 'temperature' / 'global_batch_size' AttributeError: 'TensorDict' object has no attribute 'batch' RuntimeError: tensor size mismatch (no-padding vs padding) Changes: - compute_log_prob: convert DataProto→TensorDict→no-padding before call, set compute_loss=False + calculate_entropy=True, convert output back to padded DataProto with old_log_probs/entropys keys - compute_ref_log_prob: same TensorDict handling + no_padding_2_padding - update_actor: inject mini_batch_size, epochs, seed, global_batch_size, temperature, calculate_entropy into batch before call; handle TensorDict return for metrics extraction - validation/distillation compute_log_prob: same pattern - Return-value handling is backward-compatible (isinstance checks for TensorDict vs DataProto returns); inputs to compute_log_prob / compute_ref_log_prob are always sent as no-padding TensorDicts --- rllm/trainer/verl/agent_workflow_trainer.py | 84 +++++++++++++++++++-- 1 file changed, 76 insertions(+), 8 deletions(-) diff --git a/rllm/trainer/verl/agent_workflow_trainer.py b/rllm/trainer/verl/agent_workflow_trainer.py index 3c19a74b8..279461d09 100644 --- a/rllm/trainer/verl/agent_workflow_trainer.py +++ b/rllm/trainer/verl/agent_workflow_trainer.py @@ 
-9,6 +9,7 @@ import numpy as np import torch from omegaconf import OmegaConf +from tensordict import TensorDict from verl import DataProto from verl.protocol import pad_dataproto_to_divisor from verl.single_controller.ray import RayWorkerGroup @@ -30,6 +31,8 @@ ) from verl.trainer.ppo.utils import Role, WorkerType from verl.utils.debug import marked_timer +from verl.utils import tensordict_utils as tu +from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding from rllm.engine.agent_workflow_engine import AgentWorkflowEngine from rllm.engine.rollout.verl_engine import VerlEngine @@ -325,10 +328,29 @@ def fit_agent(self): continue images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist()) batch.meta_info["images_seqlens"] = images_seqlens_all + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature # recompute old_log_probs + # Follow verl's _compute_old_log_prob pattern for new worker path: + # convert to TensorDict + no-padding, send, convert output back. 
with marked_timer("old_log_prob", timing_raw, color="blue"): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch_td = batch.to_tensordict() + batch_td = left_right_2_no_padding(batch_td) + tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False, + temperature=self.config.actor_rollout_ref.rollout.temperature) + old_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch_td) + # New worker returns TensorDict in no-padding format + if isinstance(old_log_prob_output, TensorDict): + entropy = tu.get(old_log_prob_output, "entropy") + log_probs = tu.get(old_log_prob_output, "log_probs") + # Convert from no-padding back to padding format + entropy = no_padding_2_padding(entropy, batch_td) + log_probs = no_padding_2_padding(log_probs, batch_td) + old_log_prob = DataProto.from_tensordict( + tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()}) + ) + else: + old_log_prob = old_log_prob_output entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode @@ -345,12 +367,20 @@ def fit_agent(self): metrics.update(debug_metrics) if self.use_reference_policy: - # compute reference log_prob + # compute reference log_prob (reuse batch_td from old_log_prob) with marked_timer("ref", timing_raw, color="olive"): if not self.ref_in_actor: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(batch_td) + else: + ref_log_prob_output = self.actor_rollout_wg.compute_ref_log_prob(batch_td) + if isinstance(ref_log_prob_output, TensorDict): + ref_lp = tu.get(ref_log_prob_output, "log_probs") + ref_lp = no_padding_2_padding(ref_lp, batch_td) + ref_log_prob = DataProto.from_tensordict( + tu.get_tensordict({"ref_log_prob": ref_lp.float()}) + ) else: - ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + ref_log_prob = ref_log_prob_output batch = 
batch.union(ref_log_prob) # compute values @@ -432,6 +462,26 @@ def fit_agent(self): # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: + # verl 0.7.1 new worker path: update_actor needs training metadata + ppo_mini_batch_size = ( + self.config.actor_rollout_ref.actor.ppo_mini_batch_size + * self.config.actor_rollout_ref.rollout.n + ) + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + batch.meta_info["mini_batch_size"] = ppo_mini_batch_size + batch.meta_info["epochs"] = self.config.actor_rollout_ref.actor.ppo_epochs + batch.meta_info["seed"] = self.config.actor_rollout_ref.actor.data_loader_seed + batch.meta_info["dataloader_kwargs"] = { + "shuffle": self.config.actor_rollout_ref.actor.shuffle, + } + batch.meta_info["compute_loss"] = True + tu.assign_non_tensor( + batch.batch, + temperature=self.config.actor_rollout_ref.rollout.temperature, + global_batch_size=ppo_mini_batch_size, + calculate_entropy=(self.config.actor_rollout_ref.actor.entropy_coeff != 0.0), + ) + # update actor with marked_timer("update_actor", timing_raw, color="red"): actor_output = self.actor_rollout_wg.update_actor(batch) @@ -445,7 +495,11 @@ def fit_agent(self): with marked_timer("update_weights", timing_raw, color="red"): self.checkpoint_manager.update_weights(self.global_steps) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + # verl 0.7.1 new worker returns TensorDict; extract metrics + if isinstance(actor_output, TensorDict): + actor_output_metrics = reduce_metrics(tu.get(actor_output, "metrics")) + else: + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # Visualize some sample trajectories @@ -621,9 +675,23 @@ def _validate_agent(self): try: # Concatenate all validation batches combined_batch = DataProto.concat(batches_for_distill) - - # Compute old_log_probs for distillation - old_log_prob = 
self.actor_rollout_wg.compute_log_prob(combined_batch) + combined_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + # Follow verl's no-padding pattern for compute_log_prob + cb_td = combined_batch.to_tensordict() + cb_td = left_right_2_no_padding(cb_td) + tu.assign_non_tensor(cb_td, calculate_entropy=True, compute_loss=False, + temperature=self.config.actor_rollout_ref.rollout.temperature) + + old_log_prob_output = self.actor_rollout_wg.compute_log_prob(cb_td) + if isinstance(old_log_prob_output, TensorDict): + log_probs = tu.get(old_log_prob_output, "log_probs") + log_probs = no_padding_2_padding(log_probs, cb_td) + old_log_prob = DataProto.from_tensordict( + tu.get_tensordict({"old_log_probs": log_probs.float()}) + ) + else: + old_log_prob = old_log_prob_output combined_batch = combined_batch.union(old_log_prob) # Compute distillation advantages