Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 76 additions & 8 deletions rllm/trainer/verl/agent_workflow_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import torch
from omegaconf import OmegaConf
from tensordict import TensorDict
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor
from verl.single_controller.ray import RayWorkerGroup
Expand All @@ -30,6 +31,8 @@
)
from verl.trainer.ppo.utils import Role, WorkerType
from verl.utils.debug import marked_timer
from verl.utils import tensordict_utils as tu
from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding

from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout.verl_engine import VerlEngine
Expand Down Expand Up @@ -325,10 +328,29 @@ def fit_agent(self):
continue
images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist())
batch.meta_info["images_seqlens"] = images_seqlens_all
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature

# recompute old_log_probs
# Follow verl's _compute_old_log_prob pattern for new worker path:
# convert to TensorDict + no-padding, send, convert output back.
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch_td = batch.to_tensordict()
batch_td = left_right_2_no_padding(batch_td)
tu.assign_non_tensor(batch_td, calculate_entropy=True, compute_loss=False,
temperature=self.config.actor_rollout_ref.rollout.temperature)
old_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch_td)
# New worker returns TensorDict in no-padding format
if isinstance(old_log_prob_output, TensorDict):
entropy = tu.get(old_log_prob_output, "entropy")
log_probs = tu.get(old_log_prob_output, "log_probs")
# Convert from no-padding back to padding format
entropy = no_padding_2_padding(entropy, batch_td)
log_probs = no_padding_2_padding(log_probs, batch_td)
old_log_prob = DataProto.from_tensordict(
tu.get_tensordict({"old_log_probs": log_probs.float(), "entropys": entropy.float()})
)
else:
old_log_prob = old_log_prob_output
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
Expand All @@ -345,12 +367,20 @@ def fit_agent(self):
metrics.update(debug_metrics)

if self.use_reference_policy:
# compute reference log_prob
# compute reference log_prob (reuse batch_td from old_log_prob)
with marked_timer("ref", timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(batch_td)
else:
ref_log_prob_output = self.actor_rollout_wg.compute_ref_log_prob(batch_td)
if isinstance(ref_log_prob_output, TensorDict):
ref_lp = tu.get(ref_log_prob_output, "log_probs")
ref_lp = no_padding_2_padding(ref_lp, batch_td)
ref_log_prob = DataProto.from_tensordict(
tu.get_tensordict({"ref_log_prob": ref_lp.float()})
)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
ref_log_prob = ref_log_prob_output
batch = batch.union(ref_log_prob)

# compute values
Expand Down Expand Up @@ -432,6 +462,26 @@ def fit_agent(self):

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# verl 0.7.1 new worker path: update_actor needs training metadata
ppo_mini_batch_size = (
self.config.actor_rollout_ref.actor.ppo_mini_batch_size
* self.config.actor_rollout_ref.rollout.n
)
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
batch.meta_info["mini_batch_size"] = ppo_mini_batch_size
batch.meta_info["epochs"] = self.config.actor_rollout_ref.actor.ppo_epochs
batch.meta_info["seed"] = self.config.actor_rollout_ref.actor.data_loader_seed
batch.meta_info["dataloader_kwargs"] = {
"shuffle": self.config.actor_rollout_ref.actor.shuffle,
}
batch.meta_info["compute_loss"] = True
tu.assign_non_tensor(
batch.batch,
temperature=self.config.actor_rollout_ref.rollout.temperature,
global_batch_size=ppo_mini_batch_size,
calculate_entropy=(self.config.actor_rollout_ref.actor.entropy_coeff != 0.0),
)

# update actor
with marked_timer("update_actor", timing_raw, color="red"):
actor_output = self.actor_rollout_wg.update_actor(batch)
Expand All @@ -445,7 +495,11 @@ def fit_agent(self):
with marked_timer("update_weights", timing_raw, color="red"):
self.checkpoint_manager.update_weights(self.global_steps)

actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
# verl 0.7.1 new worker returns TensorDict; extract metrics
if isinstance(actor_output, TensorDict):
actor_output_metrics = reduce_metrics(tu.get(actor_output, "metrics"))
else:
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)

# Visualize some sample trajectories
Expand Down Expand Up @@ -621,9 +675,23 @@ def _validate_agent(self):
try:
# Concatenate all validation batches
combined_batch = DataProto.concat(batches_for_distill)

# Compute old_log_probs for distillation
old_log_prob = self.actor_rollout_wg.compute_log_prob(combined_batch)
combined_batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature

# Follow verl's no-padding pattern for compute_log_prob
cb_td = combined_batch.to_tensordict()
cb_td = left_right_2_no_padding(cb_td)
tu.assign_non_tensor(cb_td, calculate_entropy=True, compute_loss=False,
temperature=self.config.actor_rollout_ref.rollout.temperature)

old_log_prob_output = self.actor_rollout_wg.compute_log_prob(cb_td)
if isinstance(old_log_prob_output, TensorDict):
log_probs = tu.get(old_log_prob_output, "log_probs")
log_probs = no_padding_2_padding(log_probs, cb_td)
old_log_prob = DataProto.from_tensordict(
tu.get_tensordict({"old_log_probs": log_probs.float()})
)
else:
old_log_prob = old_log_prob_output
combined_batch = combined_batch.union(old_log_prob)

# Compute distillation advantages
Expand Down