diff --git a/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py b/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py index e5d296a8756..71bfd0af504 100644 --- a/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py +++ b/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py @@ -14,7 +14,6 @@ from __future__ import annotations -from dataclasses import dataclass from typing import Any, Optional import numpy as np @@ -32,17 +31,13 @@ from verl.experimental.fully_async_policy.agent_loop.partial_single_turn_agent_loop import PartialSingleTurnAgentLoop from verl.protocol import DataProto from verl.utils.dataset.rl_dataset import RLHFDataset - - -@dataclass -class _FakeTokenOutput: - token_ids: list[int] - log_probs: Optional[list[float]] = None - routed_experts: Any = None - num_preempted: Optional[int] = None +from verl.workers.rollout.replica import TokenOutput class _FakeServerManager: + def __init__(self, *, return_routed_experts: bool = False): + self.return_routed_experts = return_routed_experts + async def generate( self, request_id: str, @@ -51,10 +46,20 @@ async def generate( sampling_params: dict[str, Any], image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, - ) -> _FakeTokenOutput: + ) -> TokenOutput: del request_id, sampling_params, image_data, video_data # Return a short, deterministic "generation" for testing. - return _FakeTokenOutput(token_ids=prompt_ids[-1:] + [11, 12, 13], log_probs=[0.0, 0.0, 0.0, 0.0]) + routed_experts = None + if self.return_routed_experts: + num_tokens = len(prompt_ids[-1:] + [11, 12, 13]) + num_layers = 2 + num_experts_per_tok = 2 + routed_experts = np.arange(num_tokens * num_layers * num_experts_per_tok).reshape( + num_tokens, num_layers, num_experts_per_tok + ) + return TokenOutput( + token_ids=prompt_ids[-1:] + [11, 12, 13], log_probs=[0.0, 0.0, 0.0, 0.0], routed_experts=routed_experts + ) async def generate_for_partial( self, @@ -64,12 +69,21 @@ async def generate_for_partial( sampling_params: dict[str, Any], image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, - ) -> tuple[list[int], list[float], bool]: + ) -> tuple[TokenOutput, bool]: del request_id, sampling_params, image_data, video_data # Return a short partial generation and "not cancelled". response_ids = prompt_ids[-1:] + [21, 22] response_logprobs = [0.0] * len(response_ids) - return response_ids, response_logprobs, False + routed_experts = None + if self.return_routed_experts: + # Mock routed experts for full sequence (prompt + response) + num_tokens = len(prompt_ids) + len(response_ids) + num_layers = 2 + num_experts_per_tok = 2 + routed_experts = np.arange(num_tokens * num_layers * num_experts_per_tok).reshape( + num_tokens, num_layers, num_experts_per_tok + ) + return TokenOutput(token_ids=response_ids, log_probs=response_logprobs, routed_experts=routed_experts), False class _FakeTokenizer: @@ -258,3 +272,48 @@ async def test_agent_loop_extra_fields_schema_stable_for_training_concat_on_cpu( assert merged.non_tensor_batch["tool_rewards"][0] == [] assert merged.non_tensor_batch["turn_scores"][1] == [] assert merged.non_tensor_batch["tool_rewards"][1] == [] + + +@pytest.mark.asyncio +async def test_agent_loop_with_routed_experts_on_cpu(): + """Test that routed experts (R3) are properly passed through the agent loop.""" + config = OmegaConf.create( + { + "actor_rollout_ref": {"rollout": {"prompt_length": 16, "response_length": 16}}, + "data": { + "tool_config_path": None, + "apply_chat_template_kwargs": {}, + }, + } + ) + + server_manager = _FakeServerManager(return_routed_experts=True) + tokenizer = _FakeTokenizer() + processor = None + + trainer_config = DictConfigWrap(config) + dataset_config = DictConfigWrap(config.data) + + partial_single_turn = PartialSingleTurnAgentLoop( + trainer_config=trainer_config, + server_manager=server_manager, + tokenizer=tokenizer, + processor=processor, + dataset_cls=RLHFDataset, + dataset_config=dataset_config, + ) + + raw_prompt = [{"role": "user", "content": "hi"}] + sampling_params: dict[str, Any] = {} + + output = await partial_single_turn.run(sampling_params=sampling_params, raw_prompt=raw_prompt, param_version=0) + + # Verify routed_experts is present and has correct shape + assert output.routed_experts is not None, "routed_experts should not be None when R3 is enabled" + assert isinstance(output.routed_experts, np.ndarray), "routed_experts should be a numpy array" + assert output.routed_experts.ndim == 3, "routed_experts should be 3D: [seq_len, num_layers, num_experts_per_tok]" + # Check that it has the right number of tokens (prompt + response, truncated to response_length) + expected_seq_len = min(len(output.prompt_ids) + len(output.response_ids), 16) + assert output.routed_experts.shape[0] == expected_seq_len, ( + f"routed_experts seq_len should match expected {expected_seq_len}, got {output.routed_experts.shape[0]}" + ) diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index f52ead64570..0812abef427 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -588,7 +588,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO total_length = input_ids.shape[1] length, layer_num, topk_num = output.routed_experts.shape if isinstance(output.routed_experts, np.ndarray): - experts_tensor = torch.from_numpy(output.routed_experts) + experts_tensor = torch.tensor(output.routed_experts) elif isinstance(output.routed_experts, torch.Tensor): experts_tensor = output.routed_experts else: diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index c649a2fc3fd..4357877986f 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -77,6 +77,7 @@ def __init__( self.response_ids: list[int] = [] self.response_mask: list[int] = [] self.response_logprobs: list[float] = [] + self.routed_experts: Optional[list[list[int]]] = None self.turn_scores: list[float] = [] self.tool_rewards: list[float] = [] self.user_turns = 0 @@ -85,8 +86,6 @@ def __init__( # Temporary state for tool calls self.tool_calls: list[FunctionCall] = [] - self.routed_experts = None - # Extra fields for dynamic addition, e.g., tool session data self.extra_fields: dict[str, Any] = {} @@ -190,9 +189,11 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu response_logprobs=agent_data.response_logprobs[: self.response_length] if agent_data.response_logprobs else None, + routed_experts=agent_data.routed_experts[: len(prompt_ids) + self.response_length] + if agent_data.routed_experts is not None + else None, num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, metrics=agent_data.metrics, - routed_experts=agent_data.routed_experts, extra_fields={}, ) output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards}) diff --git a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py index 56e12ee7045..f5996f2bf80 100644 --- a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py @@ -14,7 +14,7 @@ import asyncio import logging import os -from typing import Any, Optional, Sequence +from typing import Any, Optional import hydra import numpy as np @@ -39,6 +39,7 @@ rollout_trace_attr, rollout_trace_op, ) +from verl.workers.rollout.replica import TokenOutput logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -136,7 +137,7 @@ async def generate_for_partial( sampling_params: dict[str, Any], image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, - ) -> tuple[list[Any], list[Any], Any] | tuple[Sequence[int], list[float], bool]: + ) -> tuple[TokenOutput, bool]: """Generate tokens from prompt ids, used for async partial. Args: @@ -146,9 +147,8 @@ async def generate_for_partial( Returns: output: A tuple representing the generation output. - - Element 0 (Sequence[int]): Generated response token IDs. - - Element 1 (list[float]): Log probabilities for the response token IDs. - - Element 2 (bool): A flag or status indicating cancellation. + - Element 0 (TokenOutput): Generated tokens and related information (token IDs, logprobs, routed experts). + - Element 1 (bool): A flag or status indicating cancellation. """ server = self._choose_server(request_id) output = await server.generate_for_partial.remote( diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py index 92ea23c6f2c..878a998f5a2 100644 --- a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py @@ -97,13 +97,16 @@ def get_prompt_ids(): # The samples without partial rollout are returned directly. return output with simple_timer("generate_sequences", metrics): - response_ids, response_logprobs, is_cancel = await self.server_manager.generate_for_partial( + token_outputs, is_cancel = await self.server_manager.generate_for_partial( request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params, image_data=images, video_data=videos, ) + response_ids = token_outputs.token_ids + response_logprobs = token_outputs.log_probs + routed_experts = token_outputs.routed_experts # already contains routed experts for prefix if not output: response_mask = [1] * len(response_ids) else: @@ -120,6 +123,9 @@ def get_prompt_ids(): response_ids=response_ids[: self.response_length], response_mask=response_mask[: self.response_length], response_logprobs=response_logprobs[: self.response_length], + routed_experts=( + routed_experts[: len(prompt_ids) + self.response_length] if routed_experts is not None else None + ), num_turns=2, metrics=metrics, extra_fields={ diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py index 370587f0364..623bbe57aa8 100644 --- a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py @@ -171,13 +171,16 @@ async def _handle_generating_state_partial( with simple_timer("generate_sequences", agent_data.metrics): # partial interface if self.enable_partial_rollout: - response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial( + token_outputs, is_cancel = await self.server_manager.generate_for_partial( request_id=agent_data.request_id, prompt_ids=agent_data.prompt_ids, sampling_params=sampling_params, image_data=agent_data.image_data, video_data=agent_data.video_data, ) + response_ids = token_outputs.token_ids + log_probs = token_outputs.log_probs + routed_experts = token_outputs.routed_experts # already contains routed experts for prefix if is_cancel: # Save the generated parts @@ -203,6 +206,7 @@ async def _handle_generating_state_partial( ) response_ids = output.token_ids log_probs = output.log_probs + routed_experts = output.routed_experts agent_data.assistant_turns += 1 agent_data.response_ids = response_ids @@ -210,6 +214,8 @@ async def _handle_generating_state_partial( agent_data.response_mask += [1] * len(agent_data.response_ids) if log_probs: agent_data.response_logprobs += log_probs + if routed_experts is not None: + agent_data.routed_experts = routed_experts if not ignore_termination and len(agent_data.response_mask) >= self.response_length: return AgentState.TERMINATED @@ -255,6 +261,9 @@ def _build_completed_output(self, agent_data: AgentData, param_version: int) -> response_logprobs=agent_data.response_logprobs[: self.response_length] if agent_data.response_logprobs else None, + routed_experts=agent_data.routed_experts[: len(prompt_ids) + self.response_length] + if agent_data.routed_experts is not None + else None, num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, metrics=agent_data.metrics, extra_fields={}, diff --git a/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py b/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py index 89097880f4d..df4e0d85149 100644 --- a/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py +++ b/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py @@ -20,7 +20,7 @@ from ray.actor import ActorHandle from verl.workers.config import HFModelConfig, RolloutConfig -from verl.workers.rollout.replica import RolloutMode +from verl.workers.rollout.replica import RolloutMode, TokenOutput from verl.workers.rollout.sglang_rollout.async_sglang_server import ( SGLangHttpServer, SGLangReplica, @@ -117,10 +117,10 @@ async def generate_for_partial( request_id: str, image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, - ) -> tuple[list[int], list[float], bool]: + ) -> tuple[TokenOutput, bool]: async with self.lock: if self.paused: - return [], [], True + return TokenOutput(token_ids=[], log_probs=[]), True self.req_output[request_id] = None self.cancel_event[request_id] = asyncio.Event() cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait()) @@ -141,7 +141,7 @@ async def generate_for_partial( if output is None: self.cancel_event.pop(request_id, None) self.req_output.pop(request_id, None) - return [], [], True + return TokenOutput(token_ids=[], log_probs=[]), True meta_info = output.get("meta_info", {}) output_token_logprobs = meta_info.get("output_token_logprobs") @@ -155,11 +155,30 @@ async def generate_for_partial( else: token_ids = list(output["output_ids"]) log_probs = [] + + routed_experts = None + if self.config.enable_rollout_routing_replay: + if self.config.skip_tokenizer_init: + routed_experts = output.get("meta_info", {}).get("routed_experts", None) + else: + from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info + + hf_config = self.model_config.hf_config + if not hasattr(hf_config, "num_hidden_layers") or not hasattr(hf_config, "num_experts_per_tok"): + raise AttributeError( + "enable_rollout_routing_replay is set, but hf_config is missing " + "'num_hidden_layers' or 'num_experts_per_tok'. This feature requires an MoE model " + "configuration that defines these attributes." + ) + routed_experts = extract_routed_experts_from_meta_info(output).reshape( + -1, hf_config.num_hidden_layers, hf_config.num_experts_per_tok + ) + is_cancel = generation_handle not in done self.cancel_event.pop(request_id, None) self.req_output.pop(request_id, None) - return token_ids, log_probs, is_cancel + return TokenOutput(token_ids=token_ids, log_probs=log_probs, routed_experts=routed_experts), is_cancel async def cancel(self): async with self.lock: diff --git a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py index d45dbec890f..492a79cc3f9 100644 --- a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py +++ b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py @@ -13,7 +13,7 @@ # limitations under the License. import asyncio import logging -from typing import Any, Optional, Sequence +from typing import Any, Optional import ray from ray.actor import ActorHandle @@ -23,7 +23,7 @@ from verl.utils.tokenizer import normalize_token_ids from verl.workers.config import HFModelConfig, RolloutConfig -from verl.workers.rollout.replica import RolloutMode +from verl.workers.rollout.replica import RolloutMode, TokenOutput from verl.workers.rollout.vllm_rollout.vllm_async_server import ( _qwen2_5_vl_dedup_image_tokens, vLLMHttpServer, @@ -99,11 +99,11 @@ async def generate_for_partial( request_id: str, image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, - ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]: + ) -> tuple[TokenOutput, bool]: async with self.lock: if self.paused: # After cancel, all tasks will return directly and wait for the next submission - return [], [], True + return TokenOutput(token_ids=[], log_probs=[]), True self.req_output[request_id]: Optional[RequestOutput] = None self.cancel_event[request_id] = asyncio.Event() cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait()) @@ -121,7 +121,7 @@ async def generate_for_partial( async with self.lock: if self.req_output[request_id] is None: - return [], [], True + return TokenOutput(token_ids=[], log_probs=[]), True token_ids = self.req_output[request_id].outputs[0].token_ids log_probs: list[float] = [] for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs): @@ -129,10 +129,15 @@ async def generate_for_partial( # but in practice there are multiple. Take the log_prob corresponding to token_id token_id = self.req_output[request_id].outputs[0].token_ids[i] log_probs.append(x[token_id].logprob) + routed_experts = getattr(self.req_output[request_id].outputs[0], "routed_experts", None) is_cancel = generation_handle not in done self.cancel_event.pop(request_id, None) self.req_output.pop(request_id, None) - return token_ids, log_probs, is_cancel + return TokenOutput( + token_ids=token_ids, + log_probs=log_probs, + routed_experts=routed_experts, + ), is_cancel async def cancel(self): async with self.lock: