7 changes: 4 additions & 3 deletions rllm/experimental/rollout/rollout_engine.py
@@ -3,11 +3,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput
from rllm.tools.tool_base import ToolCall

if TYPE_CHECKING:
    from rllm.experimental.rollout.types import TokenInput, Tokenizer, TokenOutput
    from rllm.parser import ChatTemplateParser
    from rllm.tools.tool_base import ToolCall


@dataclass
@@ -43,6 +42,8 @@ def to_dict(self):

    @classmethod
    def from_dict(cls, data: dict):
        from rllm.tools.tool_base import ToolCall

        return cls(
            text=data.get("text"),
            content=data.get("content"),
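The `rollout_engine.py` hunk moves the `TokenInput`/`Tokenizer`/`TokenOutput` and `ToolCall` imports under `TYPE_CHECKING` and re-imports `ToolCall` locally inside `from_dict`. A minimal sketch of that deferred-import pattern, assuming the goal is avoiding a runtime import cycle; the module and class names below are illustrative, not rLLM's:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers, so nothing is imported at runtime here.
    from mypkg.tools import ToolCall  # hypothetical module


class Action:
    def __init__(self, tool_calls: list[ToolCall] | None = None):
        self.tool_calls = tool_calls or []

    @classmethod
    def from_dict(cls, data: dict) -> Action:
        # Deferred import: resolved only when from_dict runs, after both
        # modules have finished loading, which sidesteps the cycle.
        from mypkg.tools import ToolCall  # hypothetical module

        return cls(tool_calls=[ToolCall(**tc) for tc in data.get("tool_calls", [])])
```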
38 changes: 0 additions & 38 deletions rllm/experimental/verl/patch.py
@@ -10,48 +10,10 @@

logger = logging.getLogger(__name__)

_VERL_ACTOR_PATCHED = False
_VERL_DYNAMIC_BATCH_PATCHED = False
_VLLM_SDK_PATCHED = False


# ---------------------------------------------------------------------------
# Verl actor: per-call policy loss mode override
# ---------------------------------------------------------------------------


def patch_verl_actor_for_loss_override() -> None:
"""Patch ``DataParallelPPOActor.update_policy`` to support per-call loss mode.

When ``data.meta_info`` contains ``"policy_loss_mode_override"``, the
actor temporarily uses that loss mode instead of the one baked into
``self.config.policy_loss.loss_mode``. The original config value is
restored after the call (even on exception).
"""
global _VERL_ACTOR_PATCHED
if _VERL_ACTOR_PATCHED:
return

from verl.workers.actor.dp_actor import DataParallelPPOActor

_original_update_policy = DataParallelPPOActor.update_policy

def _patched_update_policy(self, data):
override = data.meta_info.get("policy_loss_mode_override")
if override is not None:
original = self.config.policy_loss.get("loss_mode", "vanilla")
self.config.policy_loss["loss_mode"] = override
try:
return _original_update_policy(self, data)
finally:
self.config.policy_loss["loss_mode"] = original
return _original_update_policy(self, data)

DataParallelPPOActor.update_policy = _patched_update_policy
_VERL_ACTOR_PATCHED = True
logger.info("Patched DataParallelPPOActor.update_policy for per-call loss mode override")


# ---------------------------------------------------------------------------
# Verl dynamic batch: sync micro-batch counts across DP ranks
# ---------------------------------------------------------------------------
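The deleted `patch_verl_actor_for_loss_override` was a wrap-and-restore monkey patch. For reference, a self-contained toy re-statement of that pattern, using a stand-in `Actor` class and a plain dict instead of verl's config objects:

```python
# Toy sketch of the removed pattern: wrap a method so a per-call override in the
# call's metadata temporarily replaces a config value, restoring it even on error.
class Actor:
    def __init__(self):
        self.config = {"loss_mode": "vanilla"}

    def update_policy(self, data):
        return f"updated with loss_mode={self.config['loss_mode']}"


_original_update_policy = Actor.update_policy


def _patched_update_policy(self, data):
    override = data.get("meta_info", {}).get("policy_loss_mode_override")
    if override is None:
        return _original_update_policy(self, data)
    original = self.config["loss_mode"]
    self.config["loss_mode"] = override
    try:
        return _original_update_policy(self, data)
    finally:
        # Restore the baked-in value even if the wrapped call raises.
        self.config["loss_mode"] = original


Actor.update_policy = _patched_update_policy

actor = Actor()
print(actor.update_policy({"meta_info": {"policy_loss_mode_override": "custom"}}))
print(actor.config["loss_mode"])  # back to "vanilla"
```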
20 changes: 15 additions & 5 deletions rllm/experimental/verl/verl_advantage.py
@@ -31,11 +31,8 @@ def compute_advantage_verl(batch: DataProto, config: DictConfig) -> tuple[DataPr
    is_last_step = batch.non_tensor_batch["is_last_step"]
    last_step_indices = np.where(is_last_step)[0]
    not_last_step_indices = np.where(~is_last_step)[0]
    non_last_step_batch = batch.select_idxs(not_last_step_indices)
    batch = batch.select_idxs(last_step_indices)

    batch = compute_advantage(
        batch,
    adv_kwargs = dict(
        adv_estimator=config.algorithm.adv_estimator,
        gamma=config.algorithm.gamma,
        lam=config.algorithm.lam,
@@ -44,6 +41,17 @@
        config=config.algorithm,
    )

    if len(not_last_step_indices) == 0:
        # All steps are last steps (e.g. single-step trajectories) — compute directly, no broadcast needed
        batch = compute_advantage(batch, **adv_kwargs)
        return batch, metrics

    # Multi-step: split by last step, compute advantages on last steps, broadcast to earlier steps
    non_last_step_batch = batch.select_idxs(not_last_step_indices)
    batch = batch.select_idxs(last_step_indices)

    batch = compute_advantage(batch, **adv_kwargs)

    _stepwise_advantage_broadcast(batch, non_last_step_batch, config)
    batch = DataProto.concat([batch, non_last_step_batch])
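The refactor above first splits the rollout rows by `is_last_step` and only runs `compute_advantage` on the last-step rows (or on the whole batch when there are no earlier steps). A small sketch of just that index split, using plain numpy in place of `DataProto`; values are made up:

```python
import numpy as np

# Two trajectories with two steps each; only each trajectory's final step is flagged.
is_last_step = np.array([False, True, False, True])

last_step_indices = np.where(is_last_step)[0]        # array([1, 3])
not_last_step_indices = np.where(~is_last_step)[0]   # array([0, 2])

if len(not_last_step_indices) == 0:
    # Single-step trajectories: advantages are computed on the whole batch,
    # nothing to broadcast.
    pass
else:
    # Multi-step: compute advantages on rows 1 and 3, broadcast the resulting
    # scalars back onto rows 0 and 2, then concatenate the two halves.
    pass
```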

@@ -73,7 +81,9 @@ def _stepwise_advantage_broadcast(last_step_batch: DataProto, non_last_step_batc

        traj_ep_to_scalar_adv[(traj_id, eps_id)] = scalar

    scalar_rows = torch.stack([torch.full_like(tgt_mask[i], fill_value=traj_ep_to_scalar_adv[(traj_id, eps_id)], dtype=torch.float32) for i, (traj_id, eps_id) in enumerate(zip(tgt_traj_ids, tgt_eps_ids, strict=False))])
    scalar_rows = torch.stack(
        [torch.full_like(tgt_mask[i], fill_value=traj_ep_to_scalar_adv[(traj_id, eps_id)], dtype=torch.float32) for i, (traj_id, eps_id) in enumerate(zip(tgt_traj_ids, tgt_eps_ids, strict=False))]
    )

    final_advantage = scalar_rows * tgt_mask
    non_last_step_batch.batch["advantages"] = final_advantage
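The reformatted `scalar_rows` expression is where the broadcast happens: each earlier step's row is filled with its trajectory's scalar advantage and then masked to its response tokens. A self-contained sketch with made-up IDs and a toy mask (not the actual `DataProto` fields):

```python
import torch

# Scalar advantage per (trajectory, episode), taken from the last step.
traj_ep_to_scalar_adv = {("traj0", 0): 1.5, ("traj1", 0): -0.5}

tgt_traj_ids = ["traj0", "traj1"]  # non-last steps to fill
tgt_eps_ids = [0, 0]
tgt_mask = torch.tensor(
    [[1.0, 1.0, 0.0, 0.0],  # response-token mask for each earlier step
     [1.0, 1.0, 1.0, 0.0]]
)

scalar_rows = torch.stack(
    [
        torch.full_like(tgt_mask[i], fill_value=traj_ep_to_scalar_adv[(traj_id, eps_id)], dtype=torch.float32)
        for i, (traj_id, eps_id) in enumerate(zip(tgt_traj_ids, tgt_eps_ids, strict=False))
    ]
)

# Keep the scalar on response tokens, zero elsewhere:
# [[ 1.5,  1.5,  0.0, 0.0],
#  [-0.5, -0.5, -0.5, 0.0]]
final_advantage = scalar_rows * tgt_mask
```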