@@ -14,7 +14,6 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
@@ -32,17 +31,13 @@
from verl.experimental.fully_async_policy.agent_loop.partial_single_turn_agent_loop import PartialSingleTurnAgentLoop
from verl.protocol import DataProto
from verl.utils.dataset.rl_dataset import RLHFDataset


@dataclass
class _FakeTokenOutput:
token_ids: list[int]
log_probs: Optional[list[float]] = None
routed_experts: Any = None
num_preempted: Optional[int] = None
from verl.workers.rollout.replica import TokenOutput


class _FakeServerManager:
def __init__(self, *, return_routed_experts: bool = False):
self.return_routed_experts = return_routed_experts

async def generate(
self,
request_id: str,
@@ -51,10 +46,20 @@ async def generate(
sampling_params: dict[str, Any],
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> _FakeTokenOutput:
) -> TokenOutput:
del request_id, sampling_params, image_data, video_data
# Return a short, deterministic "generation" for testing.
return _FakeTokenOutput(token_ids=prompt_ids[-1:] + [11, 12, 13], log_probs=[0.0, 0.0, 0.0, 0.0])
routed_experts = None
if self.return_routed_experts:
num_tokens = len(prompt_ids[-1:] + [11, 12, 13])
num_layers = 2
num_experts_per_tok = 2
routed_experts = np.arange(num_tokens * num_layers * num_experts_per_tok).reshape(
num_tokens, num_layers, num_experts_per_tok
)
return TokenOutput(
token_ids=prompt_ids[-1:] + [11, 12, 13], log_probs=[0.0, 0.0, 0.0, 0.0], routed_experts=routed_experts
)

async def generate_for_partial(
self,
@@ -64,12 +69,21 @@ async def generate_for_partial(
sampling_params: dict[str, Any],
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> tuple[list[int], list[float], bool]:
) -> tuple[TokenOutput, bool]:
del request_id, sampling_params, image_data, video_data
# Return a short partial generation and "not cancelled".
response_ids = prompt_ids[-1:] + [21, 22]
response_logprobs = [0.0] * len(response_ids)
return response_ids, response_logprobs, False
routed_experts = None
if self.return_routed_experts:
# Mock routed experts for full sequence (prompt + response)
num_tokens = len(prompt_ids) + len(response_ids)
num_layers = 2
num_experts_per_tok = 2
routed_experts = np.arange(num_tokens * num_layers * num_experts_per_tok).reshape(
num_tokens, num_layers, num_experts_per_tok
)
return TokenOutput(token_ids=response_ids, log_probs=response_logprobs, routed_experts=routed_experts), False


class _FakeTokenizer:
@@ -258,3 +272,48 @@ async def test_agent_loop_extra_fields_schema_stable_for_training_concat_on_cpu(
assert merged.non_tensor_batch["tool_rewards"][0] == []
assert merged.non_tensor_batch["turn_scores"][1] == []
assert merged.non_tensor_batch["tool_rewards"][1] == []


@pytest.mark.asyncio
async def test_agent_loop_with_routed_experts_on_cpu():
"""Test that routed experts (R3) are properly passed through the agent loop."""
config = OmegaConf.create(
{
"actor_rollout_ref": {"rollout": {"prompt_length": 16, "response_length": 16}},
"data": {
"tool_config_path": None,
"apply_chat_template_kwargs": {},
},
}
)

server_manager = _FakeServerManager(return_routed_experts=True)
tokenizer = _FakeTokenizer()
processor = None

trainer_config = DictConfigWrap(config)
dataset_config = DictConfigWrap(config.data)

partial_single_turn = PartialSingleTurnAgentLoop(
trainer_config=trainer_config,
server_manager=server_manager,
tokenizer=tokenizer,
processor=processor,
dataset_cls=RLHFDataset,
dataset_config=dataset_config,
)

raw_prompt = [{"role": "user", "content": "hi"}]
sampling_params: dict[str, Any] = {}

output = await partial_single_turn.run(sampling_params=sampling_params, raw_prompt=raw_prompt, param_version=0)

# Verify routed_experts is present and has correct shape
assert output.routed_experts is not None, "routed_experts should not be None when R3 is enabled"
assert isinstance(output.routed_experts, np.ndarray), "routed_experts should be a numpy array"
assert output.routed_experts.ndim == 3, "routed_experts should be 3D: [seq_len, num_layers, num_experts_per_tok]"
# Check that it has the right number of tokens (prompt + response, truncated to response_length)
expected_seq_len = min(len(output.prompt_ids) + len(output.response_ids), 16)
assert output.routed_experts.shape[0] == expected_seq_len, (
f"routed_experts seq_len should match expected {expected_seq_len}, got {output.routed_experts.shape[0]}"
)
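
For context, the fakes above now construct the real TokenOutput from verl.workers.rollout.replica instead of the deleted _FakeTokenOutput. A minimal sketch of the interface the tests rely on, inferred from the removed dataclass and the usages in this diff (the real class may carry more fields):

    from dataclasses import dataclass
    from typing import Any, Optional


    @dataclass
    class TokenOutput:
        token_ids: list[int]                     # generated token IDs
        log_probs: Optional[list[float]] = None  # one logprob per generated token
        routed_experts: Any = None               # e.g. np.ndarray [seq_len, num_layers, num_experts_per_tok]
        num_preempted: Optional[int] = None      # set when the request was preempted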
2 changes: 1 addition & 1 deletion verl/experimental/agent_loop/agent_loop.py
@@ -588,7 +588,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO
total_length = input_ids.shape[1]
length, layer_num, topk_num = output.routed_experts.shape
if isinstance(output.routed_experts, np.ndarray):
experts_tensor = torch.from_numpy(output.routed_experts)
experts_tensor = torch.tensor(output.routed_experts)
elif isinstance(output.routed_experts, torch.Tensor):
experts_tensor = output.routed_experts
else:
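A side note on the one-line change above: torch.from_numpy returns a tensor that shares memory with the NumPy array, while torch.tensor always copies. The PR does not state its motivation, so this is only an illustration of the behavioral difference:

    import numpy as np
    import torch

    experts = np.arange(6).reshape(3, 2)
    shared = torch.from_numpy(experts)  # view: shares memory with `experts`
    copied = torch.tensor(experts)      # always copies the data

    experts[0, 0] = 99
    assert shared[0, 0].item() == 99    # the view sees the mutation
    assert copied[0, 0].item() == 0     # the copy does not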
7 changes: 4 additions & 3 deletions verl/experimental/agent_loop/tool_agent_loop.py
@@ -77,6 +77,7 @@ def __init__(
self.response_ids: list[int] = []
self.response_mask: list[int] = []
self.response_logprobs: list[float] = []
self.routed_experts: Optional[list[list[int]]] = None
self.turn_scores: list[float] = []
self.tool_rewards: list[float] = []
self.user_turns = 0
@@ -85,8 +86,6 @@
# Temporary state for tool calls
self.tool_calls: list[FunctionCall] = []

self.routed_experts = None

# Extra fields for dynamic addition, e.g., tool session data
self.extra_fields: dict[str, Any] = {}

@@ -190,9 +189,11 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
response_logprobs=agent_data.response_logprobs[: self.response_length]
if agent_data.response_logprobs
else None,
routed_experts=agent_data.routed_experts[: len(prompt_ids) + self.response_length]
if agent_data.routed_experts is not None
else None,
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
metrics=agent_data.metrics,
routed_experts=agent_data.routed_experts,
extra_fields={},
)
output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards})
10 changes: 5 additions & 5 deletions verl/experimental/fully_async_policy/agent_loop/agent_loop.py
@@ -14,7 +14,7 @@
import asyncio
import logging
import os
from typing import Any, Optional, Sequence
from typing import Any, Optional

import hydra
import numpy as np
@@ -39,6 +39,7 @@
rollout_trace_attr,
rollout_trace_op,
)
from verl.workers.rollout.replica import TokenOutput

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -136,7 +137,7 @@ async def generate_for_partial(
sampling_params: dict[str, Any],
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> tuple[list[Any], list[Any], Any] | tuple[Sequence[int], list[float], bool]:
) -> tuple[TokenOutput, bool]:
"""Generate tokens from prompt ids, used for async partial.

Args:
@@ -146,9 +147,8 @@

Returns:
output: A tuple representing the generation output.
- Element 0 (Sequence[int]): Generated response token IDs.
- Element 1 (list[float]): Log probabilities for the response token IDs.
- Element 2 (bool): A flag or status indicating cancellation.
- Element 0 (TokenOutput): Generated tokens and related information (token IDs, logprobs, routed experts).
- Element 1 (bool): A flag or status indicating cancellation.
"""
server = self._choose_server(request_id)
output = await server.generate_for_partial.remote(
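With this change, callers unpack a TokenOutput plus a cancellation flag instead of three parallel values. A caller-side sketch of the new contract (rollout_step is a hypothetical wrapper; the real call sites appear later in this diff):

    async def rollout_step(server_manager, request_id, prompt_ids, sampling_params):
        token_output, is_cancel = await server_manager.generate_for_partial(
            request_id=request_id,
            prompt_ids=prompt_ids,
            sampling_params=sampling_params,
        )
        if is_cancel:
            # Request was cancelled or preempted; resume from the partial output later.
            return None
        return (
            token_output.token_ids,
            token_output.log_probs,       # may be empty if logprobs were not requested
            token_output.routed_experts,  # None unless routing replay is enabled
        )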
@@ -97,13 +97,16 @@ def get_prompt_ids():
# The samples without partial rollout are returned directly.
return output
with simple_timer("generate_sequences", metrics):
response_ids, response_logprobs, is_cancel = await self.server_manager.generate_for_partial(
token_outputs, is_cancel = await self.server_manager.generate_for_partial(
request_id=request_id,
prompt_ids=prompt_ids,
sampling_params=sampling_params,
image_data=images,
video_data=videos,
)
response_ids = token_outputs.token_ids
response_logprobs = token_outputs.log_probs
routed_experts = token_outputs.routed_experts # already contains routed experts for prefix
if not output:
response_mask = [1] * len(response_ids)
else:
@@ -120,6 +123,9 @@
response_ids=response_ids[: self.response_length],
response_mask=response_mask[: self.response_length],
response_logprobs=response_logprobs[: self.response_length],
routed_experts=(
routed_experts[: len(prompt_ids) + self.response_length] if routed_experts is not None else None
),
num_turns=2,
metrics=metrics,
extra_fields={
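Note the two different truncation bounds above: response_ids and response_mask cover only the response, so they are capped at response_length, while routed_experts covers the full sequence (prompt plus response) and is therefore capped at len(prompt_ids) + response_length. A small sketch of that invariant with made-up sizes:

    import numpy as np

    prompt_len, response_len_cap = 4, 16
    # routed_experts spans prompt + response: [seq_len, num_layers, num_experts_per_tok]
    routed_experts = np.zeros((prompt_len + 32, 2, 2))  # over-long generation
    trimmed = routed_experts[: prompt_len + response_len_cap]
    assert trimmed.shape[0] == prompt_len + response_len_cap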
@@ -171,13 +171,16 @@ async def _handle_generating_state_partial(
with simple_timer("generate_sequences", agent_data.metrics):
# partial interface
if self.enable_partial_rollout:
response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial(
token_outputs, is_cancel = await self.server_manager.generate_for_partial(
request_id=agent_data.request_id,
prompt_ids=agent_data.prompt_ids,
sampling_params=sampling_params,
image_data=agent_data.image_data,
video_data=agent_data.video_data,
)
response_ids = token_outputs.token_ids
log_probs = token_outputs.log_probs
routed_experts = token_outputs.routed_experts # already contains routed experts for prefix
Comment on lines +181 to +183

Contributor (high):

The routed_experts data should be saved to agent_data immediately after it is received from the server. In the current implementation, if the generation is cancelled (line 185) or terminates early due to length (line 192), the latest routed_experts information is not preserved in agent_data. This will cause the expert routing data to be lost when the task is resumed or completed, which is critical for R3 (Routing Replay) correctness.

Suggested change:

    response_ids = token_outputs.token_ids
    log_probs = token_outputs.log_probs
    routed_experts = token_outputs.routed_experts  # already contains routed experts for prefix
    if routed_experts is not None:
        agent_data.routed_experts = routed_experts

Contributor Author:

routed_experts contains the routed experts for both the input and the output. As described in the PR:

    In the case of partial rollouts, the routed experts returned are those for the last part of the generation. This is not ideal, as we should keep track of the returned experts for each part and force the rollout engine to use the previously routed experts. However, rollout engines such as vLLM do not currently support this, so we settle for this implementation.

if is_cancel:
# Save the generated parts
Expand All @@ -203,13 +206,16 @@ async def _handle_generating_state_partial(
)
response_ids = output.token_ids
log_probs = output.log_probs
routed_experts = output.routed_experts

agent_data.assistant_turns += 1
agent_data.response_ids = response_ids
agent_data.prompt_ids += agent_data.response_ids
agent_data.response_mask += [1] * len(agent_data.response_ids)
if log_probs:
agent_data.response_logprobs += log_probs
if routed_experts is not None:
agent_data.routed_experts = routed_experts

if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
return AgentState.TERMINATED
Expand Down Expand Up @@ -255,6 +261,9 @@ def _build_completed_output(self, agent_data: AgentData, param_version: int) ->
response_logprobs=agent_data.response_logprobs[: self.response_length]
if agent_data.response_logprobs
else None,
routed_experts=agent_data.routed_experts[: len(prompt_ids) + self.response_length]
if agent_data.routed_experts is not None
else None,
num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
metrics=agent_data.metrics,
extra_fields={},
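As the author's reply in the thread above explains, each generate_for_partial call returns routed experts for the whole prefix generated so far, so agent_data.routed_experts = routed_experts intentionally replaces the stored array rather than appending to it. A sketch of that replace-not-concatenate behavior under the stated assumption:

    import numpy as np

    stored = None
    for seq_len in (8, 12, 20):  # prefix grows across successive partial rollouts
        latest = np.zeros((seq_len, 2, 2))  # [seq_len, num_layers, num_experts_per_tok]
        stored = latest                     # replace, don't concatenate
    assert stored.shape[0] == 20            # only the latest full-prefix array is kept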
@@ -20,7 +20,7 @@
from ray.actor import ActorHandle

from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode
from verl.workers.rollout.replica import RolloutMode, TokenOutput
from verl.workers.rollout.sglang_rollout.async_sglang_server import (
SGLangHttpServer,
SGLangReplica,
@@ -117,10 +117,10 @@ async def generate_for_partial(
request_id: str,
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> tuple[list[int], list[float], bool]:
) -> tuple[TokenOutput, bool]:
async with self.lock:
if self.paused:
return [], [], True
return TokenOutput(token_ids=[], log_probs=[]), True
self.req_output[request_id] = None
self.cancel_event[request_id] = asyncio.Event()
cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait())
@@ -141,7 +141,7 @@ async def generate_for_partial(
if output is None:
self.cancel_event.pop(request_id, None)
self.req_output.pop(request_id, None)
return [], [], True
return TokenOutput(token_ids=[], log_probs=[]), True
meta_info = output.get("meta_info", {})
output_token_logprobs = meta_info.get("output_token_logprobs")

@@ -155,11 +155,30 @@
else:
token_ids = list(output["output_ids"])
log_probs = []

routed_experts = None
if self.config.enable_rollout_routing_replay:
if self.config.skip_tokenizer_init:
routed_experts = output.get("meta_info", {}).get("routed_experts", None)
else:
from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info

hf_config = self.model_config.hf_config
if not hasattr(hf_config, "num_hidden_layers") or not hasattr(hf_config, "num_experts_per_tok"):
raise AttributeError(
"enable_rollout_routing_replay is set, but hf_config is missing "
"'num_hidden_layers' or 'num_experts_per_tok'. This feature requires an MoE model "
"configuration that defines these attributes."
)
routed_experts = extract_routed_experts_from_meta_info(output).reshape(
-1, hf_config.num_hidden_layers, hf_config.num_experts_per_tok
)

is_cancel = generation_handle not in done
self.cancel_event.pop(request_id, None)
self.req_output.pop(request_id, None)

return token_ids, log_probs, is_cancel
return TokenOutput(token_ids=token_ids, log_probs=log_probs, routed_experts=routed_experts), is_cancel

async def cancel(self):
async with self.lock:
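In the SGLang path above, extract_routed_experts_from_meta_info yields a flat buffer that is reshaped to one expert list per (token, layer). A standalone sketch of that reshape with hypothetical MoE sizes (num_hidden_layers=2, num_experts_per_tok=2):

    import numpy as np

    num_tokens, num_hidden_layers, num_experts_per_tok = 5, 2, 2
    flat = np.arange(num_tokens * num_hidden_layers * num_experts_per_tok)
    routed_experts = flat.reshape(-1, num_hidden_layers, num_experts_per_tok)
    assert routed_experts.shape == (num_tokens, num_hidden_layers, num_experts_per_tok)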
@@ -13,7 +13,7 @@
# limitations under the License.
import asyncio
import logging
from typing import Any, Optional, Sequence
from typing import Any, Optional

import ray
from ray.actor import ActorHandle
@@ -23,7 +23,7 @@

from verl.utils.tokenizer import normalize_token_ids
from verl.workers.config import HFModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode
from verl.workers.rollout.replica import RolloutMode, TokenOutput
from verl.workers.rollout.vllm_rollout.vllm_async_server import (
_qwen2_5_vl_dedup_image_tokens,
vLLMHttpServer,
@@ -99,11 +99,11 @@ async def generate_for_partial(
request_id: str,
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]:
) -> tuple[TokenOutput, bool]:
async with self.lock:
if self.paused:
# After cancel, all tasks will return directly and wait for the next submission
return [], [], True
return TokenOutput(token_ids=[], log_probs=[]), True
self.req_output[request_id]: Optional[RequestOutput] = None
self.cancel_event[request_id] = asyncio.Event()
cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait())
@@ -121,18 +121,23 @@

async with self.lock:
if self.req_output[request_id] is None:
return [], [], True
return TokenOutput(token_ids=[], log_probs=[]), True
token_ids = self.req_output[request_id].outputs[0].token_ids
log_probs: list[float] = []
for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs):
# sampling_params requests logprobs=1, so one entry per position is expected, but in
# practice several candidates are returned; take the logprob of the sampled token_id.
token_id = self.req_output[request_id].outputs[0].token_ids[i]
log_probs.append(x[token_id].logprob)
routed_experts = getattr(self.req_output[request_id].outputs[0], "routed_experts", None)
is_cancel = generation_handle not in done
self.cancel_event.pop(request_id, None)
self.req_output.pop(request_id, None)
return token_ids, log_probs, is_cancel
return TokenOutput(
token_ids=token_ids,
log_probs=log_probs,
routed_experts=routed_experts,
), is_cancel

async def cancel(self):
async with self.lock:
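
For clarity on the logprob loop in the vLLM section above: vLLM returns, per generated position, a mapping from candidate token IDs to logprob objects, and the code selects the entry for the token actually sampled. A self-contained sketch of that lookup (FakeLogprob is a stand-in for vLLM's Logprob type):

    from dataclasses import dataclass


    @dataclass
    class FakeLogprob:  # stand-in for vLLM's Logprob
        logprob: float

    token_ids = [11, 12]
    # One dict per generated position: candidate token ID -> logprob object
    logprobs = [
        {11: FakeLogprob(-0.1), 7: FakeLogprob(-2.3)},
        {12: FakeLogprob(-0.5), 9: FakeLogprob(-1.9)},
    ]
    picked = [logprobs[i][tok].logprob for i, tok in enumerate(token_ids)]
    assert picked == [-0.1, -0.5]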