diff --git a/tests/checkpoint_engine/test_special_server_adapter.py b/tests/checkpoint_engine/test_special_server_adapter.py index 11f99eb4875..2932667df65 100644 --- a/tests/checkpoint_engine/test_special_server_adapter.py +++ b/tests/checkpoint_engine/test_special_server_adapter.py @@ -38,7 +38,7 @@ def init_config() -> DictConfig: config = compose( config_name="ppo_trainer", overrides=[ - "+async_training.partial_rollout_resume=True", + "+async_training.partial_rollout=True", ], ) @@ -73,8 +73,8 @@ async def _run_update_weights_with_global_steps_none( assert output.stop_reason not in ("aborted", "abort"), ( f"output.stop_reason is {output.stop_reason}, expected not abort" ) - assert output.extra_info["global_steps"] is None, ( - f"output.extra_info['global_steps'] is {output.extra_info['global_steps']}, expected None" + assert output.extra_fields["global_steps"] is None, ( + f"output.extra_fields['global_steps'] is {output.extra_fields['global_steps']}, expected None" ) print("========== [update_weights with global_steps=None] ==========") print("[RESPONSE]", tokenizer.decode(output.token_ids, skip_special_tokens=True)) @@ -112,7 +112,7 @@ async def _run_server_manager_without_resume( outputs = await asyncio.gather(*tasks) expected_steps = global_steps - 1 for output in outputs: - global_steps = output.extra_info["global_steps"] + global_steps = output.extra_fields["global_steps"] assert output.stop_reason in ("aborted", "abort"), ( f"output.stop_reason is {output.stop_reason}, expected in abort" ) @@ -156,8 +156,8 @@ async def _run_server_manager_with_resume( outputs = await asyncio.gather(*tasks) expected_min_steps = initial_steps - 1 for output in outputs: - min_global_steps = output.extra_info["min_global_steps"] - max_global_steps = output.extra_info["max_global_steps"] + min_global_steps = output.extra_fields["min_global_steps"] + max_global_steps = output.extra_fields["max_global_steps"] assert min_global_steps == expected_min_steps, ( 
f"output.min_global_steps is {min_global_steps}, expected {expected_min_steps}" ) diff --git a/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py b/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py index e5d296a8756..c56ae125718 100644 --- a/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py +++ b/tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py @@ -14,7 +14,6 @@ from __future__ import annotations -from dataclasses import dataclass from typing import Any, Optional import numpy as np @@ -29,17 +28,8 @@ _InternalAgentLoopOutput, ) from verl.experimental.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop -from verl.experimental.fully_async_policy.agent_loop.partial_single_turn_agent_loop import PartialSingleTurnAgentLoop -from verl.protocol import DataProto from verl.utils.dataset.rl_dataset import RLHFDataset - - -@dataclass -class _FakeTokenOutput: - token_ids: list[int] - log_probs: Optional[list[float]] = None - routed_experts: Any = None - num_preempted: Optional[int] = None +from verl.workers.rollout.replica import TokenOutput class _FakeServerManager: @@ -51,10 +41,10 @@ async def generate( sampling_params: dict[str, Any], image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, - ) -> _FakeTokenOutput: + ) -> TokenOutput: del request_id, sampling_params, image_data, video_data # Return a short, deterministic "generation" for testing. 
- return _FakeTokenOutput(token_ids=prompt_ids[-1:] + [11, 12, 13], log_probs=[0.0, 0.0, 0.0, 0.0]) + return TokenOutput(token_ids=prompt_ids[-1:] + [11, 12, 13], log_probs=[0.0, 0.0, 0.0, 0.0]) async def generate_for_partial( self, @@ -173,54 +163,30 @@ async def test_agent_loop_extra_fields_schema_stable_for_training_concat_on_cpu( dataset_cls=RLHFDataset, data_config=data_config, ) - partial_single_turn = PartialSingleTurnAgentLoop( - trainer_config=trainer_config, - server_manager=server_manager, - tokenizer=tokenizer, - processor=processor, - dataset_cls=RLHFDataset, - data_config=data_config, - ) raw_prompt = [{"role": "user", "content": "hi"}] sampling_params: dict[str, Any] = {} - out_a = await single_turn.run(sampling_params=sampling_params, raw_prompt=raw_prompt) - out_b = await partial_single_turn.run(sampling_params=sampling_params, raw_prompt=raw_prompt, param_version=0) + out = await single_turn.run(sampling_params=sampling_params, raw_prompt=raw_prompt) # Agent loop outputs should always contain these fields with consistent types. 
- assert out_a.extra_fields["turn_scores"] == [] - assert out_a.extra_fields["tool_rewards"] == [] - assert out_b.extra_fields["turn_scores"] == [] - assert out_b.extra_fields["tool_rewards"] == [] - - prompt_len = max(len(out_a.prompt_ids), len(out_b.prompt_ids)) - response_len = max(len(out_a.response_ids), len(out_b.response_ids)) + assert out.extra_fields["turn_scores"] == [] + assert out.extra_fields["tool_rewards"] == [] internal_a = _to_internal( - output_prompt_ids=out_a.prompt_ids, - output_response_ids=out_a.response_ids, - output_response_mask=out_a.response_mask, - metrics=out_a.metrics, - extra_fields=out_a.extra_fields, - num_turns=out_a.num_turns, - prompt_len=prompt_len, - response_len=response_len, - ) - internal_b = _to_internal( - output_prompt_ids=out_b.prompt_ids, - output_response_ids=out_b.response_ids, - output_response_mask=out_b.response_mask, - metrics=out_b.metrics, - extra_fields=out_b.extra_fields, - num_turns=out_b.num_turns, - prompt_len=prompt_len, - response_len=response_len, + output_prompt_ids=out.prompt_ids, + output_response_ids=out.response_ids, + output_response_mask=out.response_mask, + metrics=out.metrics, + extra_fields=out.extra_fields, + num_turns=out.num_turns, + prompt_len=len(out.prompt_ids), + response_len=len(out.response_ids), ) # Mimic two "worker chunks" and concatenate as in training. 
dummy_worker = type("_DummyWorker", (), {"reward_loop_worker_handles": None})() - chunk_a = AgentLoopWorker._postprocess( + merged = AgentLoopWorker._postprocess( dummy_worker, inputs=[internal_a], input_non_tensor_batch={ @@ -228,33 +194,21 @@ async def test_agent_loop_extra_fields_schema_stable_for_training_concat_on_cpu( "agent_name": np.array(["single_turn_agent"], dtype=object), }, ) - chunk_b = AgentLoopWorker._postprocess( - dummy_worker, - inputs=[internal_b], - input_non_tensor_batch={ - "index": np.array([1], dtype=object), - "agent_name": np.array(["partial_single_turn_agent"], dtype=object), - }, - ) - merged: DataProto = DataProto.concat([chunk_a, chunk_b]) # Stable schema: present regardless of which loop produced a sample. stable_keys = ( "turn_scores", "tool_rewards", - "is_cancel", - "param_version_start", - "param_version_end", + "min_global_steps", + "max_global_steps", "extras", ) for key in stable_keys: assert key in merged.non_tensor_batch, f"missing key in merged batch: {key}" - assert merged.non_tensor_batch[key].shape == (2,), ( + assert merged.non_tensor_batch[key].shape == (1,), ( f"invalid shape for {key}: {merged.non_tensor_batch[key].shape}" ) # And the list-typed fields are actually lists (not missing / scalar). 
assert merged.non_tensor_batch["turn_scores"][0] == [] assert merged.non_tensor_batch["tool_rewards"][0] == [] - assert merged.non_tensor_batch["turn_scores"][1] == [] - assert merged.non_tensor_batch["tool_rewards"][1] == [] diff --git a/tests/experimental/agent_loop/test_multi_modal.py b/tests/experimental/agent_loop/test_multi_modal.py index 12a87c0efe1..c552f4c1bb2 100644 --- a/tests/experimental/agent_loop/test_multi_modal.py +++ b/tests/experimental/agent_loop/test_multi_modal.py @@ -434,137 +434,3 @@ def test_multimodal_single_turn_agent(init_config): print("Single turn multimodal test passed!") ray.shutdown() - - -def test_multimodal_partial_single_turn_agent(init_config): - """Test partial single turn agent loop with multimodal inputs using Qwen VL model.""" - - # TODO(baiyan): - # see verl/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py for more details. - # if use_correct_processor=True, the test will pass but the async training will hang, so I disable this test - # for now - - return - - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - }, - ignore_reinit_error=True, - ) - from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager - - # =========================== 1. Init rollout manager =========================== - n = 2 - init_config.actor_rollout_ref.rollout.n = n - init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1 - init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1 - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - agent_loop_manager = loop.run_until_complete(FullyAsyncAgentLoopManager.create(init_config)) - - # =========================== 2. 
Generate sequences with multimodal prompts =========================== - # Create a simple test image - test_image = Image.new("RGB", (256, 256), (200, 100, 50)) - test_image2 = Image.new("RGB", (512, 512), (100, 150, 200)) - - raw_prompts = [ - [ - {"role": "user", "content": "What is the capital of France?"}, - ], - [ - { - "role": "user", - "content": [ - {"type": "image", "image": test_image}, - {"type": "text", "text": "What do you see in this image?"}, - ], - }, - ], - [ - { - "role": "system", - "content": "You are Qwen VL, a helpful multimodal assistant.", - }, - { - "role": "user", - "content": [ - {"type": "image", "image": test_image2}, - {"type": "text", "text": "Analyze the colors in this image."}, - ], - }, - ], - ] - - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), - "agent_name": np.array(["partial_single_turn_agent"] * len(raw_prompts)), - "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), - "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), - }, - ) - - batch = batch.repeat(n) - result = agent_loop_manager.generate_sequences(prompts=batch) - assert len(result) == len(raw_prompts) * n - - # Check turns - all should be single turn (2: user + assistant) - num_turns = result.non_tensor_batch["__num_turns__"] - print(f"num_turns: {num_turns}") - for i in range(len(num_turns)): - assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}" - - # Verify responses - tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) - prompts = result.batch["prompts"] - responses = result.batch["responses"] - response_mask = result.batch["response_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" - - # Check for image pads in prompts - image_pad_count = 0 - for i in range(len(prompts)): - prompt_ids = prompts[i][prompts[i] != 
tokenizer.pad_token_id].tolist() - prompt_text = tokenizer.decode(prompt_ids) - - # Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images) - sample_idx = i // n - has_image_pad = "<|image_pad|>" in prompt_text or "<|vision_start|>" in prompt_text - - print("=========================") - print(f"Sample {i} (original prompt index: {sample_idx}):") - print(f"Prompt length: {len(prompt_ids)} tokens") - print(f"Has image_pad: {has_image_pad}") - - if sample_idx != 0: # Samples 1 and 2 should have images - if has_image_pad: - image_pad_count += 1 - # Count the number of image_pad tokens - num_image_pads = prompt_text.count("<|image_pad|>") - print(f"Number of <|image_pad|> tokens: {num_image_pads}") - else: - print("WARNING: Expected image_pad but not found!") - - # Show first 200 chars of prompt - print(f"Prompt text (first 200 chars): {prompt_text[:200]}...") - - for i in range(len(responses)): - valid_tokens = responses[i][response_mask[i].bool()] - response_text = tokenizer.decode(valid_tokens) - print(f"Sample {i} response: {response_text[:100]}...") - - # Verify that we found image pads in multimodal samples - expected_multimodal_samples = 2 * n # 2 prompts with images, repeated n times - print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected") - assert image_pad_count > 0, "No image_pad tokens found in multimodal samples!" 
- - print("Partial single turn multimodal test passed!") - ray.shutdown() diff --git a/tests/special_e2e/run_fully_async_policy.sh b/tests/special_e2e/run_fully_async_policy.sh index 01d807ba63a..11db3d2d0de 100644 --- a/tests/special_e2e/run_fully_async_policy.sh +++ b/tests/special_e2e/run_fully_async_policy.sh @@ -58,9 +58,10 @@ n_resp_per_prompt=16 train_prompt_mini_bsz=16 total_rollout_steps=$(((128))) test_freq=-1 -staleness_threshold=0.1 +staleness_threshold=0.5 trigger_parameter_sync_step=4 partial_rollout=True +use_trainer_do_validate=False exp_name="$(basename "${MODEL_ID,,}")-fully-async-policy-${ACTOR_STRATEGY}-minimal" @@ -127,13 +128,13 @@ common_params=( rollout.nnodes=1 rollout.n_gpus_per_node=${n_gpus_rollout} rollout.total_rollout_steps=${total_rollout_steps} - rollout.total_epochs=2 - rollout.test_freq=${test_freq} + trainer.total_epochs=2 + trainer.test_freq=${test_freq} # Fully async specific configurations async_training.staleness_threshold=${staleness_threshold} async_training.partial_rollout="${partial_rollout}" async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" - # GPU specific configurations + async_training.use_trainer_do_validate=${use_trainer_do_validate} actor_rollout_ref.rollout.checkpoint_engine.backend='nccl' actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=1024 ) diff --git a/tests/special_npu/run_fully_async_policy.sh b/tests/special_npu/run_fully_async_policy.sh index 71b544a3c19..7bf04ad92fd 100644 --- a/tests/special_npu/run_fully_async_policy.sh +++ b/tests/special_npu/run_fully_async_policy.sh @@ -127,8 +127,8 @@ common_params=( rollout.nnodes=1 rollout.n_gpus_per_node=${n_gpus_rollout} rollout.total_rollout_steps=${total_rollout_steps} - rollout.total_epochs=2 - rollout.test_freq=${test_freq} + trainer.total_epochs=2 + trainer.test_freq=${test_freq} # Fully async specific configurations async_training.staleness_threshold=${staleness_threshold} 
async_training.partial_rollout="${partial_rollout}" diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index 46461590e94..c220227d05b 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -55,7 +55,6 @@ NCCL_KEYWORD_CHECK_WHITELIST = [ "verl/utils/device.py", "verl/third_party/sglang/parallel_state.py", # appear in default backend - "verl/recipe/fully_async_policy/param_sync.py", # fully_async_policy in default backend ] SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST diff --git a/tests/special_sanity/check_pr_title.py b/tests/special_sanity/check_pr_title.py index 1153d9d77af..df316d3d080 100644 --- a/tests/special_sanity/check_pr_title.py +++ b/tests/special_sanity/check_pr_title.py @@ -23,6 +23,7 @@ allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"] allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"] allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg", "reward"] +allowed_modules += ["fully_async", "one_step_off"] allowed_types = ["feat", "fix", "refactor", "chore", "test"] # Check for [1/N] prefix and extract the rest of the title diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 5383ae4a2a5..465cb78ec19 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -154,7 +154,7 @@ async def generate( """ server_id, server = await self._acquire_server(request_id) try: - output = await server.generate.remote( + output: TokenOutput = await server.generate.remote( request_id=uuid4().hex, # use new request_id for each turn prompt_ids=prompt_ids, sampling_params=sampling_params, @@ -839,9 +839,8 @@ def _postprocess( default_extra_keys = { "turn_scores", "tool_rewards", - "is_cancel", - "param_version_start", - 
"param_version_end", + "min_global_steps", + "max_global_steps", "extras", } all_keys = set(key for input_item in inputs for key in input_item.extra_fields) | default_extra_keys diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py index 6ad3aa429b3..d45082f5fa1 100644 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -18,6 +18,7 @@ from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register from verl.utils.profiler import simple_timer +from verl.workers.rollout.replica import TokenOutput logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -50,7 +51,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu # 3. generate sequences metrics = {} with simple_timer("generate_sequences", metrics): - output = await self.server_manager.generate( + output: TokenOutput = await self.server_manager.generate( request_id=uuid4().hex, prompt_ids=prompt_ids, sampling_params=sampling_params, @@ -61,7 +62,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu metrics["num_preempted"] = output.num_preempted if output.num_preempted is not None else -1 response_mask = [1] * len(output.token_ids) - output = AgentLoopOutput( + output: AgentLoopOutput = AgentLoopOutput( prompt_ids=prompt_ids, response_ids=output.token_ids[: self.response_length], response_mask=response_mask[: self.response_length], @@ -74,6 +75,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu multi_modal_data=multi_modal_data, num_turns=2, metrics=metrics, + extra_fields=output.extra_fields, ) # keeping the schema consistent with tool_agent_loop diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index c649a2fc3fd..007a4c46244 100644 --- 
a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -35,6 +35,7 @@ from verl.tools.utils.tool_registry import initialize_tools_from_config from verl.utils.profiler import simple_timer from verl.utils.rollout_trace import rollout_trace_op +from verl.workers.rollout.replica import TokenOutput logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -182,7 +183,8 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu multi_modal_data["images"] = agent_data.image_data if agent_data.video_data is not None: multi_modal_data["videos"] = agent_data.video_data - output = AgentLoopOutput( + + output: AgentLoopOutput = AgentLoopOutput( prompt_ids=prompt_ids, response_ids=response_ids[: self.response_length], response_mask=agent_data.response_mask[: self.response_length], @@ -193,7 +195,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, metrics=agent_data.metrics, routed_experts=agent_data.routed_experts, - extra_fields={}, + extra_fields=agent_data.extra_fields, ) output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards}) return output @@ -216,7 +218,7 @@ async def _handle_generating_state( add_messages: list[dict[str, Any]] = [] with simple_timer("generate_sequences", agent_data.metrics): - output = await self.server_manager.generate( + output: TokenOutput = await self.server_manager.generate( request_id=agent_data.request_id, prompt_ids=agent_data.prompt_ids, sampling_params=sampling_params, @@ -230,6 +232,14 @@ async def _handle_generating_state( else: agent_data.metrics["num_preempted"] += output.num_preempted if output.num_preempted is not None else 0 + if not agent_data.extra_fields: + agent_data.extra_fields.update(output.extra_fields) + else: + # Multi-round calls, only update the maximum 
max_global_steps. + max_global_steps = output.extra_fields.get("max_global_steps", None) + if max_global_steps: + agent_data.extra_fields["max_global_steps"] = max_global_steps + agent_data.assistant_turns += 1 agent_data.response_ids = output.token_ids agent_data.prompt_ids += agent_data.response_ids diff --git a/verl/experimental/fully_async_policy/agent_loop/__init__.py b/verl/experimental/fully_async_policy/agent_loop/__init__.py index ef46df0e529..1599c4a9ef9 100644 --- a/verl/experimental/fully_async_policy/agent_loop/__init__.py +++ b/verl/experimental/fully_async_policy/agent_loop/__init__.py @@ -13,8 +13,5 @@ # limitations under the License. from .agent_loop import FullyAsyncAgentLoopManager -from .partial_single_turn_agent_loop import PartialSingleTurnAgentLoop -from .partial_tool_agent_loop import AsyncPartialToolAgentLoop -_ = [PartialSingleTurnAgentLoop, AsyncPartialToolAgentLoop] __all__ = [FullyAsyncAgentLoopManager] diff --git a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py index bea24f7a3fb..16b670ba579 100644 --- a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py +++ b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py @@ -14,29 +14,22 @@ import asyncio import logging import os -from typing import Any, Optional, Sequence +from typing import Any, Optional -import hydra -import numpy as np import ray import torch from omegaconf import DictConfig from verl.experimental.agent_loop.agent_loop import ( AgentLoopManager, - AgentLoopOutput, AgentLoopWorker, AsyncLLMServerManager, - DictConfigWrap, TokenOutput, - _agent_loop_registry, - _get_rollout_and_model_config, - get_trajectory_info, ) from verl.protocol import DataProto from verl.single_controller.ray import RayResourcePool, RayWorkerGroup +from verl.utils.ray_utils import auto_await from verl.utils.rollout_trace import ( - rollout_trace_attr, rollout_trace_op, ) @@ -65,6 +58,8 @@ async def 
generate( request_id (str): request id for sticky session. prompt_ids (List[int]): List of prompt token ids. sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. + image_data (Optional[List[Any]]): Image data for the chat completion. + video_data (Optional[List[Any]]): Video data for the chat completion. Returns: TokenOutput: token output @@ -107,7 +102,7 @@ async def generate( final_output.stop_reason = output.stop_reason # update model weights version - global_steps = output.extra_info.get("global_steps", None) + global_steps = output.extra_fields.get("global_steps", None) if min_global_steps is None: min_global_steps = global_steps max_global_steps = global_steps @@ -120,49 +115,13 @@ async def generate( break # 4. check stop reason - if output.stop_reason not in ("aborted", "abort") or not self.config.async_training.partial_rollout_resume: + if output.stop_reason not in ("aborted", "abort") or not self.config.async_training.partial_rollout: break - final_output.extra_info["global_steps"] = global_steps - final_output.extra_info["min_global_steps"] = min_global_steps - final_output.extra_info["max_global_steps"] = max_global_steps + final_output.extra_fields["global_steps"] = global_steps + final_output.extra_fields["min_global_steps"] = min_global_steps + final_output.extra_fields["max_global_steps"] = max_global_steps return final_output - @rollout_trace_op - async def generate_for_partial( - self, - request_id, - *, - prompt_ids: list[int], - sampling_params: dict[str, Any], - image_data: Optional[list[Any]] = None, - video_data: Optional[list[Any]] = None, - ) -> tuple[list[Any], list[Any], Any] | tuple[Sequence[int], list[float], bool]: - """Generate tokens from prompt ids, used for async partial. - - Args: - request_id (str): request id for sticky session. - prompt_ids (List[int]): List of prompt token ids. - sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. 
- - Returns: - output: A tuple representing the generation output. - - Element 0 (Sequence[int]): Generated response token IDs. - - Element 1 (list[float]): Log probabilities for the response token IDs. - - Element 2 (bool): A flag or status indicating cancellation. - """ - server_id, server = await self._acquire_server(request_id=request_id) - try: - output = await server.generate_for_partial.remote( - request_id=request_id, - prompt_ids=prompt_ids, - sampling_params=sampling_params, - image_data=image_data, - video_data=video_data, - ) - return output - finally: - self._release_server(server_id) - @ray.remote class FullyAsyncAgentLoopWorker(AgentLoopWorker): @@ -173,142 +132,8 @@ def __init__( load_balancer_handle: ray.actor.ActorHandle, reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, ): - self.server_manager = FullyAsyncLLMServerManager( - config, - servers, - load_balancer_handle=load_balancer_handle, - ) - super().__init__( - config, - servers, - load_balancer_handle, - reward_loop_worker_handles, - ) - # A shared cancellation event for all agent loops running on this worker. - self.cancellation_event = asyncio.Event() - - async def generate_sequences_no_post( - self, batch: DataProto, partial_output_list: Optional[list[AgentLoopOutput]] - ) -> tuple[list[AgentLoopOutput], bool] | tuple[DataProto, bool]: - """Generate sequences from agent loop. - - Args: - batch (DataProto): Input batch. - partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result. - - Returns: - list[AgentLoopOutput]: List of agent loop outputs, one per sample in the batch. 
- """ - config = self.rollout_config - sampling_params = dict( - temperature=config.temperature, - top_p=config.top_p, - repetition_penalty=1.0, - logprobs=config.calculate_log_probs, - ) - - # override sampling params for validation - if batch.meta_info.get("validate", False): - sampling_params["top_p"] = config.val_kwargs.top_p - sampling_params["temperature"] = config.val_kwargs.temperature - - if "agent_name" not in batch.non_tensor_batch: - default_agent_loop = config.agent.default_agent_loop - batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object) - - if "index" in batch.non_tensor_batch: - index = batch.non_tensor_batch["index"] - else: - index = np.arange(len(batch)) - - trajectory_info = await get_trajectory_info( - batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) - ) - - if not partial_output_list: - partial_output_list = [None] * len(batch) - try: - tasks = [] - for i in range(len(batch)): - kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()} - kwargs["output"] = partial_output_list[i] - tasks.append( - asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs)) - ) - outputs = await asyncio.gather(*tasks) - except Exception: - logger.exception("_partial_run_agent_loop failed") - raise - - is_cancel = any(output.extra_fields.get("is_cancel", False) for output in outputs) - if not is_cancel: - output = self._postprocess(outputs) - output = self._addition_process(output) - return output, is_cancel - return outputs, is_cancel - - def _addition_process(self, output: DataProto): - """collect metrics""" - metrics = output.meta_info.pop("metrics") # List[Dict[str, str]] - processing_times_list = [item["generate_sequences"] for item in metrics] - tool_calls_times_list = [item["tool_calls"] for item in metrics] - output.non_tensor_batch["processing_times"] = processing_times_list - output.non_tensor_batch["tool_calls_times"] = 
tool_calls_times_list - return output - - async def _partial_run_agent_loop( - self, - sampling_params: dict[str, Any], - trajectory: dict[str, Any], - *, - agent_name: str, - **kwargs, - ) -> AgentLoopOutput: - # Completed, return directly - if kwargs["output"] is not None and not kwargs["output"].extra_fields.get("is_cancel", False): - logger.info("In _partial_run_agent_loop, already completed, return directly!") - return kwargs["output"] - try: - with rollout_trace_attr( - step=trajectory["step"], - sample_index=trajectory["sample_index"], - rollout_n=trajectory["rollout_n"], - validate=trajectory["validate"], - name="agent_loop", - ): - assert agent_name in _agent_loop_registry, ( - f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" - ) - - agent_loop_config = _agent_loop_registry[agent_name] - agent_loop = hydra.utils.instantiate( - config=agent_loop_config, - trainer_config=DictConfigWrap(config=self.config), - server_manager=self.server_manager, - tokenizer=self.tokenizer, - processor=self.processor, - dataset_cls=self.dataset_cls, - data_config=DictConfigWrap(config=self.config.data), - ) - output: AgentLoopOutput = await agent_loop.run( - sampling_params, cancellation_event=self.cancellation_event, **kwargs - ) - if not output.extra_fields.get("is_cancel", False): - kwargs.pop("output", None) - output = await self._agent_loop_postprocess(output, **kwargs) - - return output - except Exception: - logger.exception("Agent_loop run failed") - raise - - async def cancel_agent_loops(self): - """Set the shared cancellation event to stop all agent loops.""" - self.cancellation_event.set() - - async def resume_agent_loops(self): - """Clear the shared cancellation event.""" - self.cancellation_event.clear() + self.server_manager = FullyAsyncLLMServerManager(config, servers, load_balancer_handle) + super().__init__(config, servers, load_balancer_handle, reward_loop_worker_handles) class 
FullyAsyncAgentLoopManager(AgentLoopManager): @@ -319,49 +144,20 @@ def __init__( rollout_resource_pool: RayResourcePool = None, reward_loop_worker_handles: list[ray.actor.ActorHandle] = None, ): - self.config = config - self.rollout_config, self.model_config = _get_rollout_and_model_config(config) - self.worker_group = worker_group - self.reward_loop_worker_handles = reward_loop_worker_handles self.agent_loop_workers_class = FullyAsyncAgentLoopWorker + super().__init__(config, worker_group, rollout_resource_pool, reward_loop_worker_handles) - # Select rollout replica class based on rollout name - rollout_name = self.rollout_config.name - if rollout_name == "sglang": - from verl.experimental.fully_async_policy.sglang_rollout.sglang_async_server import FullyAsyncSGLangReplica - - self.rollout_replica_class = FullyAsyncSGLangReplica - print("[FullyAsyncAgentLoopManager] SGLang replica class selected") - elif rollout_name == "vllm": - from verl.experimental.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica - - self.rollout_replica_class = FullyAsyncvLLMReplica - print("[FullyAsyncAgentLoopManager] vLLM replica class selected") - else: - raise ValueError(f"Unsupported rollout name: {rollout_name}. Supported values are 'sglang' and 'vllm'.") - - self.rollout_replicas = None - self.server_handles = None - self.server_addresses = None - self.agent_loop_workers = None - - async def generate_single_sample_async( - self, - sample: DataProto, - partial_output_list: Optional[list[AgentLoopOutput]], - ) -> tuple[list[AgentLoopOutput], bool] | tuple[DataProto, bool]: - """ - Asynchronously process a single sample + @auto_await + async def generate_sequences_single(self, prompts: DataProto) -> DataProto: + """Split input batch and dispatch to agent loop workers. Args: - sample: Single sample data - partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result. - + prompts (DataProto): Input batch. 
Single sample data Returns: - list[AgentLoopOutput]: Processing results + DataProto: Output batch. """ worker = self._select_best_worker() - output_future = worker.generate_sequences_no_post.remote(sample, partial_output_list) + output_future = worker.generate_sequences.remote(prompts) return await asyncio.wrap_future(output_future.future()) def _select_best_worker(self): @@ -372,22 +168,3 @@ def _select_best_worker(self): worker = self.agent_loop_workers[self._worker_index] self._worker_index = (self._worker_index + 1) % len(self.agent_loop_workers) return worker - - async def cancel(self): - worker_cancel_tasks = [worker.cancel_agent_loops.remote() for worker in self.agent_loop_workers] - rollout_cancel_tasks = [replica.cancel() for replica in self.rollout_replicas] - await asyncio.gather(*rollout_cancel_tasks, *worker_cancel_tasks) - - async def resume(self): - rollout_resume_tasks = [replica.resume() for replica in self.rollout_replicas] - worker_resume_tasks = [worker.resume_agent_loops.remote() for worker in self.agent_loop_workers] - await asyncio.gather(*rollout_resume_tasks, *worker_resume_tasks) - - async def wake_up(self): - await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas]) - - async def sleep(self): - await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas]) - - async def clear_kv_cache(self): - await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas]) diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py deleted file mode 100644 index 92ea23c6f2c..00000000000 --- a/verl/experimental/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2025 Meituan Ltd. 
and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import os -from typing import Any, Optional -from uuid import uuid4 - -from verl.experimental.agent_loop import AgentLoopBase -from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register -from verl.utils.profiler import simple_timer -from verl.utils.tokenizer import normalize_token_ids - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -@register("partial_single_turn_agent") -class PartialSingleTurnAgentLoop(AgentLoopBase): - """Naive agent loop that only do single turn chat completion.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.prompt_length = self.rollout_config.prompt_length - self.response_length = self.rollout_config.response_length - self.apply_chat_template_kwargs = self.data_config.get("apply_chat_template_kwargs", {}) - - async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: - output: Optional[AgentLoopOutput] = kwargs.get("output", None) - messages = list(kwargs["raw_prompt"]) - multi_modal_data = await self.process_vision_info(messages) - images = multi_modal_data.get("images") - videos = multi_modal_data.get("videos") - - param_version = kwargs.get("param_version", 0) - - metrics = {} - request_id = uuid4().hex - - param_version_start = param_version - param_version_end = param_version - - if not output: - # TODO(baiyan): it is 
supposed to use the correct processor, - # but I found the async training would hang if use_correct_processor=True. - # so we use the tokenizer to tokenize the prompt for now. - use_correct_processor = False - if self.processor is not None and use_correct_processor: - - def get_prompt_ids(): - raw_prompt = self.processor.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=False, - **self.apply_chat_template_kwargs, - ) - model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") - return model_inputs.pop("input_ids").squeeze(0).tolist() - - prompt_ids = await self.loop.run_in_executor(None, get_prompt_ids) - # Refer to the implementation of the run function in verl/experimental/agent_loop/single_turn_agent_loop.py - elif self.processor is not None: - prompt_ids = await self.apply_chat_template( - messages, - images=images, - videos=videos, - ) - else: - tokenized_prompt = await self.loop.run_in_executor( - None, - lambda: self.tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs - ), - ) - prompt_ids = normalize_token_ids(tokenized_prompt) - else: - if output.extra_fields.get("is_cancel", False): - # Resume the paused sample, - # add the result directly after prompt_ids, - # and reset generate_sequences metric - prompt_ids = output.prompt_ids + output.response_ids - metrics["generate_sequences"] = output.metrics.generate_sequences - param_version_start = output.extra_fields.get("param_version_start", param_version) - else: - # In the same batch of samples, - # some are canceled and some are not. - # The samples without partial rollout are returned directly. 
- return output - with simple_timer("generate_sequences", metrics): - response_ids, response_logprobs, is_cancel = await self.server_manager.generate_for_partial( - request_id=request_id, - prompt_ids=prompt_ids, - sampling_params=sampling_params, - image_data=images, - video_data=videos, - ) - if not output: - response_mask = [1] * len(response_ids) - else: - # Pause the sample to be resumed, add the output result to response_ids, and reset response_mask - prompt_ids = output.prompt_ids - response_logprobs = output.response_logprobs + response_logprobs - response_ids = output.response_ids + response_ids - response_mask = [1] * len(response_ids) - if len(response_ids) >= self.response_length: - is_cancel = False - - return AgentLoopOutput( - prompt_ids=prompt_ids, - response_ids=response_ids[: self.response_length], - response_mask=response_mask[: self.response_length], - response_logprobs=response_logprobs[: self.response_length], - num_turns=2, - metrics=metrics, - extra_fields={ - "is_cancel": is_cancel, - "param_version_start": param_version_start, - "param_version_end": param_version_end, - "turn_scores": [], - "tool_rewards": [], - }, - multi_modal_data=multi_modal_data, - # multi_modal_data={"image": image_data} if image_data is not None else {}, - ) diff --git a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py deleted file mode 100644 index 370587f0364..00000000000 --- a/verl/experimental/fully_async_policy/agent_loop/partial_tool_agent_loop.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import logging -import os -from typing import Any, Optional -from uuid import uuid4 - -from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register -from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop -from verl.utils.profiler import simple_timer - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -@register("async_partial_tool_agent") -class AsyncPartialToolAgentLoop(ToolAgentLoop): - """ - Support for partial rollout with multiple tool invocations in Agent Loop - - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.enable_partial_rollout = self.config.async_training.get("partial_rollout", False) - - # async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: - async def run( - self, sampling_params: dict[str, Any], *, cancellation_event: asyncio.Event = None, **kwargs - ) -> AgentLoopOutput: - """ - Main entrance, supports interruption/recovery - - Args: - sampling_params: Sampling parameters - cancellation_event: cancellationn sginal - **kwargs: Contains output (for recovery), raw_prompt, param_version, etc. - - Returns: - AgentLoopOutput: Include the is_cancel flag - """ - param_version = kwargs.get("param_version", 0) - agent_data = None - state = None - - # 1. 
check whether is the partial task - output: Optional[AgentLoopOutput] = kwargs.get("output", None) - if output and output.extra_fields.get("is_cancel", False): - agent_data, state = self._restore_from_output(output) - - logger.info(f"[PartialToolAgent] Resuming from {state.value}") - else: - if output and not output.extra_fields.get("is_cancel", False): - # Completed, return directly - return output - - agent_data = await self._init_agent_data(kwargs, param_version) - state = AgentState.PENDING - logger.info("[PartialToolAgent] Start from scratch") - # 2. run state machine - state = await self._run_state_machine(agent_data, state, sampling_params, cancellation_event) - - # 3. bulid output - if state == AgentState.TERMINATED: - return self._build_completed_output(agent_data, param_version) - else: - # build cancelled output - return self._build_cancelled_output(agent_data, state) - - async def _init_agent_data(self, kwargs: dict, param_version: int) -> AgentData: - messages = list(kwargs["raw_prompt"]) - multi_modal_data = await self.process_vision_info(messages) - image_data = multi_modal_data.get("images") - video_data = multi_modal_data.get("videos") - metrics = {} - request_id = uuid4().hex - tools_kwargs = kwargs.get("tools_kwargs", {}) - - # Initialize interaction if needed - interaction = None - interaction_kwargs = {} - if self.interaction_config_file: - interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"] - if "name" not in interaction_kwargs: - raise ValueError("'name' key is required in interaction_kwargs") - interaction_name = interaction_kwargs["name"] - if interaction_name not in self.interaction_map: - raise ValueError( - f"Interaction '{interaction_name}' not found in interaction_map. 
Available interactions: " - f"{list(self.interaction_map.keys())}" - ) - interaction = self.interaction_map[interaction_name] - await interaction.start_interaction(request_id, **interaction_kwargs) - # Create AgentData instance to encapsulate all state - agent_data = AgentData( - messages=messages, - image_data=image_data, - video_data=video_data, - metrics=metrics, - request_id=request_id, - tools_kwargs=tools_kwargs, - interaction=interaction, - interaction_kwargs=interaction_kwargs, - ) - - # additional param version record - agent_data.extra_fields["param_version_start"] = param_version - agent_data.extra_fields["param_version_end"] = param_version - - return agent_data - - def _restore_from_output(self, output: AgentLoopOutput) -> tuple[AgentData, AgentState]: - """restore AgentState and AgentData from output""" - agent_data = output.extra_fields.get("agent_data", None) - agent_state = output.extra_fields.get("agent_state", None) - if agent_data is None or agent_state is None: - raise ValueError(f"Unexpected situation: agent_data is {agent_data}, agent_state is {agent_state}") - return agent_data, agent_state - - async def _run_state_machine( - self, - agent_data: AgentData, - state: AgentState, - sampling_params: dict[str, Any], - cancellation_event: asyncio.Event = None, - ) -> AgentState: - """ - State machine. - Currently, interruptions are only supported to occur in the GENERATING state or other states have ended. - """ - # State machine loop - while state != AgentState.TERMINATED: - if cancellation_event and cancellation_event.is_set(): - logger.info(f"[PartialToolAgent] Cancellation detected. 
Interrupted before/at state: {state.value}") - return state - if state == AgentState.PENDING: - state = await self._handle_pending_state(agent_data, sampling_params) - elif state == AgentState.GENERATING: - state = await self._handle_generating_state_partial(agent_data, sampling_params) - elif state == AgentState.PROCESSING_TOOLS: - state = await self._handle_processing_tools_state(agent_data) - elif state == AgentState.INTERACTING: - state = await self._handle_interacting_state(agent_data) - else: - logger.error(f"[PartialToolAgent] Invalid state: {state}") - return AgentState.TERMINATED - - return AgentState.TERMINATED - - async def _handle_generating_state_partial( - self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False - ) -> AgentState: - """ - Handle GENERATING state, support partial rollout - """ - add_messages: list[dict[str, Any]] = [] - - with simple_timer("generate_sequences", agent_data.metrics): - # partial interface - if self.enable_partial_rollout: - response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial( - request_id=agent_data.request_id, - prompt_ids=agent_data.prompt_ids, - sampling_params=sampling_params, - image_data=agent_data.image_data, - video_data=agent_data.video_data, - ) - - if is_cancel: - # Save the generated parts - agent_data.response_ids = response_ids - agent_data.prompt_ids += agent_data.response_ids - agent_data.response_mask += [1] * len(response_ids) - if log_probs: - agent_data.response_logprobs += log_probs - if not ignore_termination and len(agent_data.response_mask) >= self.response_length: - # If response_length has reached the limit, - # it is considered to have ended normally. 
- agent_data.assistant_turns += 1 - return AgentState.TERMINATED - return AgentState.GENERATING - else: - # original generate interface - output = await self.server_manager.generate( - request_id=agent_data.request_id, - prompt_ids=agent_data.prompt_ids, - sampling_params=sampling_params, - image_data=agent_data.image_data, - video_data=agent_data.video_data, - ) - response_ids = output.token_ids - log_probs = output.log_probs - - agent_data.assistant_turns += 1 - agent_data.response_ids = response_ids - agent_data.prompt_ids += agent_data.response_ids - agent_data.response_mask += [1] * len(agent_data.response_ids) - if log_probs: - agent_data.response_logprobs += log_probs - - if not ignore_termination and len(agent_data.response_mask) >= self.response_length: - return AgentState.TERMINATED - if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns: - return AgentState.TERMINATED - if self.max_user_turns and agent_data.user_turns >= self.max_user_turns: - return AgentState.TERMINATED - - # Extract tool calls - _, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(agent_data.response_ids) - - # Handle interaction if needed - if self.interaction_config_file: - assistant_message = await self.loop.run_in_executor( - None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True) - ) - add_messages.append({"role": "assistant", "content": assistant_message}) - agent_data.messages.extend(add_messages) - - # Determine next state - if agent_data.tool_calls: - return AgentState.PROCESSING_TOOLS - elif self.interaction_config_file: - return AgentState.INTERACTING - else: - return AgentState.TERMINATED - - def _build_completed_output(self, agent_data: AgentData, param_version: int) -> AgentLoopOutput: - """build completed output""" - response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :] - prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)] - 
multi_modal_data = {} - if agent_data.image_data is not None: - multi_modal_data["image"] = agent_data.image_data - if agent_data.video_data is not None: - multi_modal_data["video"] = agent_data.video_data - - output = AgentLoopOutput( - prompt_ids=prompt_ids, - response_ids=response_ids[: self.response_length], - response_mask=agent_data.response_mask[: self.response_length], - multi_modal_data=multi_modal_data, - response_logprobs=agent_data.response_logprobs[: self.response_length] - if agent_data.response_logprobs - else None, - num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, - metrics=agent_data.metrics, - extra_fields={}, - ) - output.extra_fields.update( - { - "turn_scores": agent_data.turn_scores, - "tool_rewards": agent_data.tool_rewards, - "is_cancel": False, - "param_version_start": agent_data.extra_fields["param_version_start"], - "param_version_end": param_version, - } - ) - return output - - def _build_cancelled_output(self, agent_data: AgentData, state: AgentState) -> AgentLoopOutput: - """build cancelled output""" - return AgentLoopOutput( - prompt_ids=[], - response_ids=[], - response_mask=[], - multi_modal_data={}, - response_logprobs=None, - num_turns=0, - metrics=agent_data.metrics, - extra_fields={ - "is_cancel": True, - "agent_data": agent_data, - "agent_state": state, - }, - ) diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml index d1c753864f9..9acc742817e 100644 --- a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml +++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml @@ -21,17 +21,13 @@ async_training: # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once require_batches: 1 - # When synchronizing parameters, whether to interrupt rollouter and perform partial rollout + # When synchronizing parameters, Whether to 
resume generation when rollout is interrupted. + # If True, AsyncLLMServerManager auto resume generation, making rollout interruption invisible to the AgentLoop. partial_rollout: True - # Whether to resume generation when rollout is interrupted. If True, AsyncLLMServerManager - # auto resume generation, making rollout interruption invisible to the AgentLoop. - partial_rollout_resume: True - # whether to use trainer do_validate use_trainer_do_validate: False - # Rollout config rollout: @@ -47,12 +43,6 @@ rollout: # total rollout samples # TODO rename to total_rollout_samples total_rollout_steps: 100 - # Number of epochs in training - total_epochs: 10 - - # Test frequency, how many times a parameter update triggers a validation - test_freq: 1 - data: # Number of samples generated, currently only support 1 gen_batch_size: 1 @@ -60,8 +50,6 @@ data: actor_rollout_ref: rollout: - # Must be turned off! Otherwise, Parameter synchronization cannot be performed. - free_cache_engine: False # Must be enabled! Otherwise, log_probs cannot be calculated. calculate_log_probs: True diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml index 76e446ce166..f0753d969ee 100644 --- a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml +++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml @@ -21,13 +21,10 @@ async_training: # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once require_batches: 1 - # When synchronizing parameters, whether to interrupt rollouter and perform partial rollout + # When synchronizing parameters, Whether to resume generation when rollout is interrupted. + # If True, AsyncLLMServerManager auto resume generation, making rollout interruption invisible to the AgentLoop. partial_rollout: True - # Whether to resume generation when rollout is interrupted. 
If True, AsyncLLMServerManager - # auto resume generation, making rollout interruption invisible to the AgentLoop. - partial_rollout_resume: True - # whether to use trainer do_validate use_trainer_do_validate: False @@ -46,12 +43,6 @@ rollout: # total rollout samples # TODO rename to total_rollout_samples total_rollout_steps: 100 - # Number of epochs in training - total_epochs: 10 - - # Test frequency, how many times a parameter update triggers a validation - test_freq: 1 - data: # Number of samples generated, currently only support 1 gen_batch_size: 1 @@ -59,8 +50,6 @@ data: actor_rollout_ref: rollout: - # Must be turned off! Otherwise, Parameter synchronization cannot be performed. - free_cache_engine: False # Must be enabled! Otherwise, log_probs cannot be calculated. calculate_log_probs: True diff --git a/verl/experimental/fully_async_policy/detach_utils.py b/verl/experimental/fully_async_policy/detach_utils.py index c8d2c02ebca..47c9d4d05c1 100644 --- a/verl/experimental/fully_async_policy/detach_utils.py +++ b/verl/experimental/fully_async_policy/detach_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import time from collections import defaultdict from dataclasses import dataclass @@ -20,7 +21,6 @@ import torch from verl import DataProto -from verl.experimental.agent_loop.agent_loop import AgentLoopOutput from verl.trainer.ppo.ray_trainer import compute_response_mask @@ -31,19 +31,11 @@ class RolloutSample: # Original batch information full_batch: Any - # AgentLoopOutput from generation - agent_loop_output_list: list[AgentLoopOutput] - # Metadata sample_id: str epoch: int # Processing metadata - processing_times: list[float] - tool_calls: list[float] - param_version: int - param_version_start: list[int] - param_version_end: list[int] rollout_status: dict[str, Any] @@ -53,8 +45,6 @@ class ValidateMetrics: timing_raw: dict[str, Any] metrics: Optional[dict[str, Any]] = None - global_steps: Optional[int] = None - param_version: Optional[int] = None def prepare_single_generation_data(batch_dict, config) -> DataProto: @@ -82,19 +72,25 @@ def prepare_single_generation_data(batch_dict, config) -> DataProto: # Setting selected agent, that supports partial if config.actor_rollout_ref.rollout.multi_turn.enable: - full_batch.non_tensor_batch["agent_name"] = np.array( - ["async_partial_tool_agent"] * len(full_batch), dtype=object - ) + full_batch.non_tensor_batch["agent_name"] = np.array(["tool_agent"] * len(full_batch), dtype=object) else: - full_batch.non_tensor_batch["agent_name"] = np.array( - ["partial_single_turn_agent"] * len(full_batch), dtype=object - ) + full_batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(full_batch), dtype=object) # Add global step count to generated data full_batch = full_batch.repeat(repeat_times=config.actor_rollout_ref.rollout.n, interleave=True) return full_batch +def addition_process(output: DataProto): + """collect metrics""" + metrics = output.meta_info.pop("metrics") # List[Dict[str, str]] + processing_times_list = [item["generate_sequences"] for item in metrics] + tool_calls_times_list = 
[item["tool_calls"] for item in metrics] + output.non_tensor_batch["processing_times"] = processing_times_list + output.non_tensor_batch["tool_calls_times"] = tool_calls_times_list + return output + + def assemble_batch_from_rollout_samples( rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None ) -> DataProto: @@ -122,14 +118,13 @@ def assemble_batch_from_rollout_samples( print(f"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects") rollout_samples_batch = [] - processing_times = [] - tool_calls = [] rollout_status = rollout_samples[0].rollout_status # Add a prefix to all rollout_status keys rollout_status = {f"fully_async/{key}": value for key, value in rollout_status.items()} for rs in rollout_samples: - rollout_samples_batch.append(rs.full_batch) + batch = addition_process(rs.full_batch) + rollout_samples_batch.append(batch) final_batch = DataProto.concat(rollout_samples_batch) # Calculate response_mask (if not present) @@ -146,7 +141,6 @@ def assemble_batch_from_rollout_samples( processing_times = final_batch.non_tensor_batch["processing_times"] tool_calls = final_batch.non_tensor_batch["tool_calls_times"] # Collect statistics - processing_time_stats = { "processing_time/avg": np.mean(processing_times), "processing_time/max": np.max(processing_times), @@ -164,8 +158,8 @@ def assemble_batch_from_rollout_samples( } processing_time_stats = {f"fully_async/{key}": value for key, value in processing_time_stats.items()} - param_version_start = final_batch.non_tensor_batch["param_version_start"] - param_version_end = final_batch.non_tensor_batch["param_version_end"] + param_version_start = final_batch.non_tensor_batch["min_global_steps"] + param_version_end = final_batch.non_tensor_batch["max_global_steps"] param_version_diff = [abs(a - b) for a, b in zip(param_version_end, param_version_start, strict=False)] num_diff0 = param_version_diff.count(0) partial_stats = { @@ -174,14 +168,12 @@ def 
assemble_batch_from_rollout_samples( "fully_async/partial/max_partial_span": max(param_version_diff), } # add meta_info - param_versions = [rs.param_version for rs in rollout_samples] - trajectorys_param_versions = final_batch.non_tensor_batch["param_version_end"] + trajectory_param_versions = final_batch.non_tensor_batch["max_global_steps"] final_batch.meta_info.update( { - "rollout_param_versions": param_versions, - "param_version_diversity": len(set(param_versions)) if param_versions else 0, - "trajectory_param_versions": trajectorys_param_versions, + "param_version_diversity": len(set(trajectory_param_versions)), + "trajectory_param_versions": trajectory_param_versions, **processing_time_stats, **rollout_status, **partial_stats, @@ -322,7 +314,7 @@ def get_aggregated_metrics(self) -> dict[str, Any]: # Aggregate special metrics aggregated = self._special_metrics_aggergate(aggregated) - print(f"aggregated metrics done. cost {time.time() - t}") + print(f"aggregated metrics done. cost {time.time() - t:.4f} seconds.") return aggregated @@ -342,7 +334,7 @@ def _special_metrics_aggergate(self, aggregated: dict[str, Any]) -> dict[str, An # trainer/idle_ratio if "timing_s/gen" in aggregated.keys() and "timing_s/step" in aggregated.keys(): - aggregated["trainer/idle_ratio"] = aggregated["timing_s/gen"] / aggregated["timing_s/step"] + aggregated["fully_async/trainer/idle_ratio"] = aggregated["timing_s/gen"] / aggregated["timing_s/step"] return aggregated @@ -361,3 +353,32 @@ def get_current_stats(self) -> dict[str, Any]: "total_samples": sum(self.sample_counts), "metric_names": list(self.metric_values.keys()), } + + +def task_exception_handler(task: asyncio.Task): + """Handle task exceptions and log them""" + try: + task.result() + except asyncio.CancelledError: + pass # Task was cancelled, this is expected + except Exception as e: + print(f"Task {task.get_name()} failed with exception: {e}") + raise e + + +def safe_create_task(coro, name: str, task_set: set = None): + 
"""Safely create a task with exception handling + + Args: + coro: The coroutine to run + name: Name for the task + task_set: Optional set to add the task to + + Returns: + The created asyncio.Task + """ + task = asyncio.create_task(coro, name=name) + task.add_done_callback(task_exception_handler) + if task_set is not None: + task_set.add(task) + return task diff --git a/verl/experimental/fully_async_policy/fsdp2_utils.py b/verl/experimental/fully_async_policy/fsdp2_utils.py deleted file mode 100644 index 1f1856596fb..00000000000 --- a/verl/experimental/fully_async_policy/fsdp2_utils.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -import torch.distributed as dist -from packaging import version -from torch.distributed.tensor import DTensor -from torch.distributed.tensor._dtensor_spec import DTensorSpec - -if version.parse(torch.__version__) < version.parse("2.6"): - raise RuntimeError("PyTorch 2.6 or higher is required to use fstp_utils.") - - -def fsdp2_sharded_save_to_cpu( - model: torch.nn.Module, -) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]: - """ - Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory. - - Args: - model: FSDP2-wrapped model whose parameters are of DTensor type. 
- - Returns: - cpu_sharded_state: Dictionary of CPU shards for the current process. - Key = parameter name, Value = (CPU shard tensor, original DTensorSpec) - global_spec: DTensorSpec of the first parameter (used to verify global rules during loading) - """ - cpu_sharded_state = {} - global_spec = None # Record global sharding rules (all parameters follow the same spec) - - for param_name, param in model.named_parameters(): - # Only process sharded parameters of DTensor type (core parameters of FSDP2) - if not isinstance(param, DTensor): - # Save non-sharded parameters (e.g., running_mean of BatchNorm) as local data - cpu_tensor = param.detach().cpu() - cpu_sharded_state[param_name] = (cpu_tensor, None) - continue - - # Record global sharding rules (take spec of the first DTensor to ensure consistency) - if global_spec is None: - global_spec = param._spec - assert hasattr(global_spec, "device_mesh"), "DTensorSpec must contain 'device_mesh' attribute" - assert hasattr(global_spec, "placements"), "DTensorSpec must contain 'placements' attribute" - - # 1. Extract local shard data from the current GPU (_local_tensor) - local_gpu_tensor = param._local_tensor # Local shard attribute defined in your DTensor class - # 2. Move to CPU memory and detach from computation graph - local_cpu_tensor = local_gpu_tensor.detach().cpu() - # 3. Save CPU shard + original DTensorSpec (ensure sharding rules remain unchanged) - cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec) - - assert global_spec is not None, "No DTensor-type parameters found in the model. FSDP2 sharding may not be enabled." - return cpu_sharded_state, global_spec - - -def fsdp2_sharded_load_from_cpu( - model: torch.nn.Module, - cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]], - target_spec: DTensorSpec, -) -> None: - """ - Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU, - keeping sharding rules unchanged. 
- - Args: - model: FSDP2 model to be restored (must have the same structure as when saved) - cpu_sharded_state: Shard data read from CPU memory by the current process - (from fsdp2_sharded_save_to_cpu) - target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency) - """ - # Verify device_mesh consistency (core: ensure loaded shards map to original GPUs) - current_device_mesh = None - for param in model.parameters(): - if isinstance(param, DTensor): - current_device_mesh = param._spec.device_mesh - break - assert current_device_mesh is not None, "DTensor parameters not initialized in the model to be loaded" - assert current_device_mesh == target_spec.device_mesh, ( - f"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}" - ) - - for param_name, param in model.named_parameters(): - # Skip parameters not in the saved state (e.g., newly added parameters) - if param_name not in cpu_sharded_state: - continue - - # Extract CPU shard data and original Spec - local_cpu_tensor, saved_spec = cpu_sharded_state[param_name] - - # Handle different parameter types: DTensor sharded parameters vs. regular parameters - if isinstance(param, DTensor): - # 1. Verify sharding rule consistency (placements must match original Spec) - assert saved_spec is not None, f"DTensorSpec missing in saved state for parameter {param_name}" - assert saved_spec.placements == target_spec.placements, ( - f"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!" - ) - - # 2. Move CPU shard data to the current GPU (device of param._local_tensor) - target_device = param._local_tensor.device - local_gpu_tensor = local_cpu_tensor.to(target_device) - - # 3. 
Restore to DTensor's local shard (directly copy to _local_tensor, keep spec unchanged) - param._local_tensor.copy_(local_gpu_tensor) - - else: - # Regular parameters: load directly to original device - target_device = param.device - param.data.copy_(local_cpu_tensor.to(target_device)) - - # Process synchronization: ensure all processes complete loading before proceeding - dist.barrier() diff --git a/verl/experimental/fully_async_policy/fully_async_main.py b/verl/experimental/fully_async_policy/fully_async_main.py index 4e9e509475f..18b2eddb056 100644 --- a/verl/experimental/fully_async_policy/fully_async_main.py +++ b/verl/experimental/fully_async_policy/fully_async_main.py @@ -25,86 +25,11 @@ from verl.experimental.fully_async_policy.fully_async_rollouter import FullyAsyncRollouter from verl.experimental.fully_async_policy.fully_async_trainer import FullyAsyncTrainer from verl.experimental.fully_async_policy.message_queue import MessageQueue, MessageQueueClient -from verl.trainer.ppo.ray_trainer import ResourcePoolManager -from verl.trainer.ppo.utils import Role, need_reference_policy +from verl.experimental.separation.utils import create_resource_pool_manager, create_role_worker_mapping +from verl.trainer.ppo.utils import Role from verl.utils.fs import copy_to_local -def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager: - """ - Create resource pool manager - - Args: - config: Configuration object - roles: List of roles that need to create resource pools - - Returns: - ResourcePoolManager: Resource pool manager - """ - resource_pool_spec = {} - mapping = {} - - # Actor/Critic resource pool - if any(role in roles for role in [Role.Actor, Role.ActorRollout, Role.Critic, Role.RefPolicy, Role.RewardModel]): - assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0" - assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0" - - trainer_pool = [config.trainer.n_gpus_per_node] * 
config.trainer.nnodes - resource_pool_spec["trainer_pool"] = trainer_pool - - # Map training-related roles to the same resource pool - for role in [Role.Actor, Role.ActorRollout, Role.Critic, Role.RefPolicy, Role.RewardModel]: - if role in roles: - mapping[role] = "trainer_pool" - - # Rollout resource pool - if Role.Rollout in roles: - assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" - assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" - - rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes - resource_pool_spec["rollout_pool"] = rollout_pool - mapping[Role.Rollout] = "rollout_pool" - - return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - -def create_role_worker_mapping(config): - """ - Create mapping from roles to worker classes - - Args: - config: Configuration object - - Returns: - dict: Mapping from roles to worker classes - """ - # Select worker class based on strategy - use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") - if use_legacy_worker_impl == "disable": - from verl.experimental.separation.engine_workers import DetachActorWorker - from verl.single_controller.ray import RayWorkerGroup - from verl.workers.engine_workers import TrainingWorker - - ray_worker_group_cls = RayWorkerGroup - - CriticWorker = TrainingWorker - else: - raise NotImplementedError("Fully async policy does not support legacy worker implementation") - - train_role = Role.ActorRollout if config.async_training.use_trainer_do_validate else Role.Actor - role_worker_mapping = { - train_role: ray.remote(DetachActorWorker), - Role.Critic: ray.remote(CriticWorker), - } - - # Add reference policy (if KL loss or reward is required) - if need_reference_policy(config): - role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker) - - return role_worker_mapping, ray_worker_group_cls - - @ray.remote(num_cpus=1) class 
FullyAsyncTaskRunner: """ @@ -151,7 +76,6 @@ def _initialize_components(self, config) -> None: print("[ASYNC MAIN] Creating FullyAsyncRollouter and FullyAsyncTrainer in parallel...") with ThreadPoolExecutor(max_workers=2) as executor: - # TODO: keep _create_rollouter and _create_trainer parallel # Rollouter does not permit continuous allocation, so we allocate trainer first. trainer_future = executor.submit(self._create_trainer, config) trainer_future.result() @@ -175,35 +99,23 @@ def _initialize_components(self, config) -> None: ray.get(self.components["rollouter"].set_message_queue_client.remote(self.components["message_queue_client"])) ray.get(self.components["trainer"].set_message_queue_client.remote(self.components["message_queue_client"])) + # param_version resume from ckpt or default 0 + ray.get(self.components["trainer"].load_checkpoint.remote()) + ray.get(self.components["rollouter"].load_checkpoint.remote()) + print("[ASYNC MAIN] Setting up parameter synchronization...") - from verl.experimental.fully_async_policy.param_sync import ParameterSynchronizer + ray.get(self.components["trainer"].set_rollouter.remote(self.components["rollouter"])) - param_synchronizer = ParameterSynchronizer.remote( - config=config, - trainer=self.components["trainer"], - rollouter=self.components["rollouter"], - mq=self.components["message_queue_client"], - ) - ray.get(self.components["trainer"].set_parameter_synchronizer.remote(param_synchronizer)) + print("[ASYNC MAIN] Param sync before fit..") + ray.get(self.components["trainer"]._fit_update_weights.remote()) - # load checkpoint and sync parameter before doing anything - val_before_train = config.trainer.get("val_before_train", True) - # param_version resume from ckpt or default 0 - param_version = ray.get(self.components["trainer"].load_checkpoint.remote()) - ray.get(self.components["rollouter"].load_checkpoint.remote()) - ray.get( - param_synchronizer.sync_weights.remote( - version=param_version, - 
validate=val_before_train, - use_trainer_do_validate=config.async_training.use_trainer_do_validate, - ) - ) - ray.get(param_synchronizer.wait_last_valid.remote()) + if config.trainer.get("val_before_train", True): + ray.get(self.components["trainer"]._fit_validate.remote(True)) - self.components["param_synchronizer"] = param_synchronizer print("[ASYNC MAIN] All components initialized successfully") def _create_rollouter(self, config) -> None: + print("[ASYNC MAIN] Starting create rollouter...") rollouter = FullyAsyncRollouter.remote( config=config, tokenizer=self.components["tokenizer"], @@ -221,6 +133,7 @@ def _create_rollouter(self, config) -> None: print("[ASYNC MAIN] Rollouter created and initialized successfully") def _create_trainer(self, config) -> None: + print("[ASYNC MAIN] Starting create trainer...") trainer_role_mapping = { role: worker_cls for role, worker_cls in self.components["role_worker_mapping"].items() @@ -284,6 +197,9 @@ def main(config): # Ensure async training config exists if not hasattr(config, "async_training"): raise RuntimeError("must set async_training config") + + assert config.async_training.use_trainer_do_validate is False, "use_trainer_do_validate is not ready to use." + from time import time start_time = time() diff --git a/verl/experimental/fully_async_policy/fully_async_rollouter.py b/verl/experimental/fully_async_policy/fully_async_rollouter.py index ce21be271cf..fdde54a67c4 100644 --- a/verl/experimental/fully_async_policy/fully_async_rollouter.py +++ b/verl/experimental/fully_async_policy/fully_async_rollouter.py @@ -13,7 +13,6 @@ # limitations under the License. 
import asyncio -import functools import multiprocessing import os import time @@ -23,12 +22,12 @@ import numpy as np import ray import torch -from ray import ObjectRef from verl.experimental.fully_async_policy.detach_utils import ( RolloutSample, ValidateMetrics, prepare_single_generation_data, + safe_create_task, ) from verl.experimental.fully_async_policy.message_queue import MessageQueueClient from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer @@ -69,7 +68,7 @@ def __init__( assert self.config.data.gen_batch_size == 1, "gen_batch_size must be one" assert self.config.async_training.staleness_threshold >= 0, "staleness_threshold must larger than 0" assert self.config.async_training.trigger_parameter_sync_step >= 1, ( - "trigger_parameter_sync_step must larger than 1" + "trigger_parameter_sync_step must larger or equal than 1" ) self.role_worker_mapping = role_worker_mapping @@ -155,21 +154,19 @@ def __init__( self.max_queue_size = None # Statistics - self.current_param_version = 0 self.total_generated_samples = 0 self.staleness_samples = 0 self.dropped_stale_samples = 0 self.processed_sample_count = 0 # we start from step 1 self.global_steps = 1 - self.idle_start_time = None - self.version_start_time = None + self.idle_start_time = time.time() + self.step_start_time = time.time() # Concurrency control # Modified by self.pause() or self._should_pause_generation() self.paused = False self.running = True - self.monitor_loop_trigger = True # Add dataloader lock self.dataloader_lock = asyncio.Lock() @@ -177,12 +174,10 @@ def __init__( # Initialize async queues self.pending_queue = asyncio.Queue(maxsize=128) self.active_tasks = set() - self.cancel_queue = asyncio.Queue() cpu_cores = multiprocessing.cpu_count() # cpu case use cpu_cores; io case use cpu_cores*2 self.validate_executor = ThreadPoolExecutor(max_workers=cpu_cores) - self.parallel_validate_and_rollout = config.async_training.get("parallel_validate_and_rollout", False) self.validate_task 
= None def _init_async_objects(self): @@ -239,96 +234,38 @@ def get_max_queue_size(self): def get_total_train_steps(self): return self.total_train_steps - async def update_param_version( - self, version: int, validate: bool = False, global_steps: int = 0, use_trainer_do_validate: bool = False - ): - """Update current parameter version""" + async def reset_staleness(self): + """ + Reset staleness samples after parameter update. + Returns timing_raw dictionary for metrics. + """ async with self.lock: - old_version = self.current_param_version - self.current_param_version = version + self.paused = False + self.condition.notify_all() # every time param change, reset staleness_samples - self.staleness_samples = ( - len(self.active_tasks) + self.cancel_queue.qsize() + await self.message_queue_client.get_queue_size() - ) + self.staleness_samples = len(self.active_tasks) + await self.message_queue_client.get_queue_size() timing_raw = {} - idle_ratio = None - if self.idle_start_time is not None and self.version_start_time is not None: - rollout_active_time = self.idle_start_time - self.version_start_time - rollout_version_time = time.time() - self.version_start_time - idle_ratio = 1 - rollout_active_time / rollout_version_time - timing_raw["rollouter/active_time"] = rollout_active_time - timing_raw["rollouter/version_time"] = rollout_version_time - timing_raw["rollouter/idle_ratio"] = idle_ratio - self.idle_start_time = None - print( - f"[FullyAsyncRollouter][Public][update_param_version] " - f"Parameter version updated from {old_version} to {version} " - f",reset staleness_samples to: {self.staleness_samples}" - f",idle_ratio: {idle_ratio}" - ) - need_validate = ( - ( - self.config.rollout.test_freq > 0 - and self.current_param_version % self.config.rollout.test_freq == 0 - and self.current_param_version > 0 - ) # don't test here in the initial parameter sync - or validate - ) - print( - f"[FullyAsyncRollouter] need_validate: {need_validate}, " - 
f"parallel_validate_and_rollout: {self.parallel_validate_and_rollout}" - ) - if not need_validate: - data = ValidateMetrics( - timing_raw=timing_raw, metrics=None, global_steps=global_steps, param_version=version - ) - elif need_validate and not self.parallel_validate_and_rollout: - data = self._validate_wrapper(timing_raw, version, global_steps, use_trainer_do_validate) - - if not need_validate or not self.parallel_validate_and_rollout: - await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data)) - - self.version_start_time = time.time() + rollout_active_time = self.idle_start_time - self.step_start_time + rollout_version_time = time.time() - self.step_start_time + idle_ratio = 1 - rollout_active_time / rollout_version_time + timing_raw["fully_async/rollouter/active_time"] = rollout_active_time + timing_raw["fully_async/rollouter/version_time"] = rollout_version_time + timing_raw["fully_async/rollouter/idle_ratio"] = idle_ratio - if need_validate and self.parallel_validate_and_rollout: - if self.validate_task and not self.validate_task.done(): - print("[FullyAsyncRollouter] validate_task is running, wait last validate_task to finish") - self.validate_task.get() - self.validate_task = asyncio.create_task( - self.do_validate_async(timing_raw, version, global_steps, use_trainer_do_validate) + print( + f"[FullyAsyncRollouter][Public][reset_staleness] " + f"reset staleness_samples to: {self.staleness_samples} " + f"idle_ratio: {timing_raw['fully_async/rollouter/idle_ratio']:.4f}" ) + self.step_start_time = time.time() + return timing_raw - def _validate_wrapper( - self, timing_raw: dict, version: int, global_steps: int = 0, use_trainer_do_validate: bool = False - ): - val_metrics = None + def do_validate(self) -> ValidateMetrics: + """Run validation and return metrics""" + timing_raw = {} with marked_timer("rollouter/validate_time", timing_raw, color="green"): - val_metrics: dict = self._validate(use_trainer_do_validate) - data = ValidateMetrics( - 
timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version - ) - return data - - async def do_validate_async( - self, - timing_raw: dict, - version: int, - global_steps: int = 0, - use_trainer_do_validate: bool = False, - ): - loop = asyncio.get_running_loop() - - data = await loop.run_in_executor( - self.validate_executor, - functools.partial( - self._validate_wrapper, - timing_raw=timing_raw, - version=version, - global_steps=global_steps, - use_trainer_do_validate=use_trainer_do_validate, - ), - ) - await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data)) + val_metrics: dict = self._validate() + return ValidateMetrics(timing_raw=timing_raw, metrics=val_metrics) async def save_checkpoint(self, local_global_step_folder: str): # WARNING!: Due to the asynchronous nature, there are some in-flight samples @@ -435,7 +372,7 @@ def _create_continuous_iterator(self): """ Create a continuous data iterator across epoch """ - for epoch in range(self.config.rollout.total_epochs): + for epoch in range(self.config.trainer.total_epochs): iterator = iter(self.train_dataloader) for batch_dict in iterator: yield epoch, batch_dict @@ -471,14 +408,8 @@ async def _feed_samples(self): rollout_sample = RolloutSample( full_batch=full_batch, - agent_loop_output_list=[None] * self.config.actor_rollout_ref.rollout.n, sample_id=sample_id, epoch=epoch, - param_version=0, - param_version_start=[], - param_version_end=[], - processing_times=[], - tool_calls=[], rollout_status={}, ) @@ -496,7 +427,7 @@ async def _feed_samples(self): self.global_steps += 1 # End signal - await self.pending_queue.put("DONE") + await self.pending_queue.put(None) print(f"[FullyAsyncRollouter][Feed] Sample addition is complete, {self.global_steps} samples have been added") async def _processor_worker(self): @@ -517,24 +448,20 @@ async def _processor_worker(self): done_tasks, self.active_tasks = await asyncio.wait( self.active_tasks, 
return_when=asyncio.FIRST_COMPLETED ) - for task in done_tasks: - await task + for task in done_tasks: + await task async with self.lock: while self.paused: self.idle_start_time = time.time() await self.condition.wait() continue + # Get sample from appropriate queue and immediately mark task as done + rollout_sample = await self.pending_queue.get() + self.pending_queue.task_done() + self.staleness_samples += 1 - simple_from_cancel_queue = False - if not self.cancel_queue.empty(): - rollout_sample = await self.cancel_queue.get() - simple_from_cancel_queue = True - else: - rollout_sample = await self.pending_queue.get() - self.staleness_samples += 1 - - if rollout_sample == "DONE": + if rollout_sample is None: print( "[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete..." ) @@ -544,8 +471,8 @@ async def _processor_worker(self): done_tasks, self.active_tasks = await asyncio.wait( self.active_tasks, return_when=asyncio.FIRST_COMPLETED ) - for task in done_tasks: - await task + for task in done_tasks: + await task break # Check whether the number of concurrent tasks exceeds the limit @@ -555,8 +482,8 @@ async def _processor_worker(self): done_tasks, self.active_tasks = await asyncio.wait( self.active_tasks, return_when=asyncio.FIRST_COMPLETED ) - for task in done_tasks: - await task + for task in done_tasks: + await task # Submit single sample processing async with self.lock: @@ -564,47 +491,29 @@ async def _processor_worker(self): # to determine whether it is the pause phase, otherwise continue to wait while self.paused: await self.condition.wait() - task = asyncio.create_task( + task = safe_create_task( self._process_single_sample_streaming(rollout_sample), name=rollout_sample.sample_id, + task_set=self.active_tasks, ) - self.active_tasks.add(task) - - if simple_from_cancel_queue: - self.cancel_queue.task_done() - else: - self.pending_queue.task_done() async def _process_single_sample_streaming(self, rollout_sample: 
RolloutSample): """Process a single sample streamingly""" # Calling asynchronous generation methods - rollout_sample.full_batch.non_tensor_batch["param_version"] = [self.current_param_version] * len( - rollout_sample.full_batch - ) - ret, is_cancel = await self.async_rollout_manager.generate_single_sample_async( - rollout_sample.full_batch, rollout_sample.agent_loop_output_list + ret = await self.async_rollout_manager.generate_sequences_single(rollout_sample.full_batch) + rollout_sample.full_batch = ret + rollout_sample.full_batch.non_tensor_batch["uid"] = np.array( + [f"uid_{rollout_sample.sample_id}"] * len(rollout_sample.full_batch), dtype=object ) - if not is_cancel: - rollout_sample.full_batch = ret - rollout_sample.full_batch.non_tensor_batch["uid"] = np.array( - [f"uid_{rollout_sample.sample_id}"] * len(rollout_sample.full_batch), dtype=object - ) - rollout_sample.param_version = self.current_param_version - rollout_sample.rollout_status = await self.get_statistics() - rollout_sample.agent_loop_output_list = [] + rollout_sample.rollout_status = await self.get_statistics() - success = await self.message_queue_client.put_sample( - sample=ray.cloudpickle.dumps(rollout_sample), - param_version=rollout_sample.param_version, - ) - if success: - self.total_generated_samples += 1 - else: - self.dropped_stale_samples += 1 + success = await self.message_queue_client.put_sample( + sample=ray.cloudpickle.dumps(rollout_sample), + ) + if success: + self.total_generated_samples += 1 else: - rollout_sample.agent_loop_output_list = ret - await self.cancel_queue.put(rollout_sample) - + self.dropped_stale_samples += 1 self.processed_sample_count += 1 async def _streaming_generation_main(self): @@ -617,8 +526,8 @@ async def _streaming_generation_main(self): print(f"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}") # Start sample feed coroutine, streaming process coroutine - self.feed_task = 
asyncio.create_task(self._feed_samples()) - self.processor_task = asyncio.create_task(self._processor_worker()) + self.feed_task = safe_create_task(self._feed_samples(), name="feed_task") + self.processor_task = safe_create_task(self._processor_worker(), name="processor_task") try: # Wait for sample feed to complete @@ -641,20 +550,27 @@ async def _streaming_generation_main(self): await self.processor_task print("[FullyAsyncRollouter] Streaming process completed") + await self.pending_queue.join() + print("[FullyAsyncRollouter] pending_queue joined") + except Exception as e: - print(f"[FullyAsyncRollouter] Streaming process exception:{e}") + print(f"[FullyAsyncRollouter] Streaming process exception: {e}") + raise e finally: - if self.processor_task: + if self.feed_task and not self.feed_task.done(): + self.feed_task.cancel() + await asyncio.gather(self.feed_task, return_exceptions=True) + + if self.processor_task and not self.processor_task.done(): self.processor_task.cancel() + await asyncio.gather(self.processor_task, return_exceptions=True) - await asyncio.gather(self.processor_task, return_exceptions=True) + self.feed_task = None + self.processor_task = None - # Send a finish signal - await self.message_queue_client.put_sample( - sample=None, - param_version=self.current_param_version, - ) + # Send a finish signal + await self.message_queue_client.put_sample(sample=None) async with self.lock: self.running = False @@ -676,8 +592,8 @@ async def fit(self): self.running = True # Create the main asynchronous task - generation_task = asyncio.create_task(self._streaming_generation_main()) - monitor_task = asyncio.create_task(self._async_monitor_loop()) + generation_task = safe_create_task(self._streaming_generation_main(), name="generation_task") + monitor_task = safe_create_task(self._async_monitor_loop(), name="monitor_task") try: # Run build and monitoring tasks concurrently @@ -718,11 +634,11 @@ async def _async_monitor_loop(self): last_stats_time = current_time # 
Trigger rollout recovery - if self.monitor_loop_trigger: - if not await self._should_pause_generation(): - async with self.lock: - self.paused = False - self.condition.notify_all() + if self.paused and not await self._should_pause_generation(): + async with self.lock: + self.paused = False + print("[FullyAsyncRollouter][ShouldPause] notify all wait tasks.") + self.condition.notify_all() async def _should_pause_generation(self) -> bool: """Determine whether the build should be paused""" @@ -748,36 +664,6 @@ async def _should_pause_generation(self) -> bool: return False - async def pause(self): - """pause rollout""" - print("[FullyAsyncRollouter][Public][Pause] partial rollout:", self.config.async_training.partial_rollout) - async with self.lock: - self.paused = True - # Cancel all rollout tasks - if self.config.async_training.partial_rollout: - await self.async_rollout_manager.cancel() - print("[FullyAsyncRollouter][Public][Pause] Unfinished rollout tasks canceled") - if self.active_tasks: - await asyncio.gather(*self.active_tasks, return_exceptions=True) - self.active_tasks.clear() - print("[FullyAsyncRollouter][Public][Pause] All active tasks completed") - print("[FullyAsyncRollouter][Public][Pause] Prefix cache reset") - # Always clear KV cache to release GPU memory during weight synchronization, - # regardless of partial_rollout setting. 
- await self.async_rollout_manager.clear_kv_cache() - self.monitor_loop_trigger = False - - async def resume(self, dependency_ref: ObjectRef = None): - if dependency_ref is not None: - ray.get(dependency_ref) - print("[FullyAsyncRollouter][Public][Resume]") - async with self.lock: - if self.config.async_training.partial_rollout: - await self.async_rollout_manager.resume() - self.paused = False - self.monitor_loop_trigger = True - self.condition.notify_all() - async def get_statistics(self) -> dict: queue_stats = self.message_queue_client.get_statistics_sync() @@ -785,10 +671,8 @@ async def get_statistics(self) -> dict: # monitor stats "monitor/active_tasks_size": len(self.active_tasks), "monitor/queue/pending_queue_size": self.pending_queue.qsize(), - "monitor/queue/cancel_queue_size": self.cancel_queue.qsize(), "monitor/queue/mq_queue_size": queue_stats["queue_size"], # counting stats - "count/current_param_version": self.current_param_version, "count/total_generated_samples": self.total_generated_samples, "count/staleness_samples": self.staleness_samples, "count/dropped_stale_samples": self.dropped_stale_samples, diff --git a/verl/experimental/fully_async_policy/fully_async_trainer.py b/verl/experimental/fully_async_policy/fully_async_trainer.py index 9519c594dbd..e9728064ee6 100644 --- a/verl/experimental/fully_async_policy/fully_async_trainer.py +++ b/verl/experimental/fully_async_policy/fully_async_trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging import os import time from datetime import datetime @@ -19,9 +20,11 @@ from typing import Any import ray +from omegaconf import OmegaConf, open_dict from tqdm import tqdm from verl import DataProto +from verl.checkpoint_engine import CheckpointEngineManager from verl.experimental.fully_async_policy.detach_utils import ( MetricsAggregator, ValidateMetrics, @@ -34,7 +37,11 @@ from verl.trainer.ppo.ray_trainer import ResourcePoolManager from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass from verl.utils.debug import marked_timer +from verl.utils.tracking import Tracking + +logger = logging.getLogger(__name__) class TrainingStopException(Exception): @@ -99,7 +106,6 @@ def __init__( self.epoch = 0 self.max_steps_duration = 0 self.progress_bar = None - self.logger = None self.is_last_step = False self.prev_step_profile = False self.curr_step_profile = False @@ -112,24 +118,26 @@ def __init__( self.reward_tensor = None self.reward_extra_infos_dict = {} + self.logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + # ==================== fully async config ==================== self.message_queue_client = None - self.param_synchronizer = None # Statistics - # we start from step 1 - self.global_steps = 1 self.local_trigger_step = 1 self.processed_samples = 0 - self.stale_samples_processed = 0 self.stale_trajectory_processed = 0 self.current_param_version = 0 self.total_train_steps = None self.progress_bar = None self.trigger_parameter_sync_step = config.async_training.trigger_parameter_sync_step self.last_ckpt_version = 0 - self.train_val_metrics = None self.train_role = 
Role.ActorRollout if config.async_training.use_trainer_do_validate else Role.Actor # required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples. @@ -171,24 +179,52 @@ drop_last=False, collate_fn=collate_fn, ) + # Reference to rollouter for parameter synchronization + self.rollouter = None + self.checkpoint_manager = None + + # when use_trainer_do_validate == True, use colocate_checkpoint_manager to sync params + self.colocate_checkpoint_manager = None + + def _setup_checkpoint_manager(self, rollouter): + """Setup checkpoint manager after rollouter is initialized""" + replicas = ray.get(rollouter.get_replicas.remote()) + checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine) + self.checkpoint_manager = CheckpointEngineManager( + config=checkpoint_engine_config, trainer=self.actor_wg, replicas=replicas + ) + print("[FullyAsyncTrainer] Checkpoint manager initialized") def set_message_queue_client(self, message_queue_client: MessageQueueClient): """Set message queue client""" self.message_queue_client = message_queue_client - def set_parameter_synchronizer(self, param_synchronizer): - """Set parameter synchronizer""" - self.param_synchronizer = param_synchronizer + def set_rollouter(self, rollouter): + """Set rollouter reference for parameter synchronization""" + self.rollouter = rollouter + # Setup checkpoint manager after rollouter is set + self._setup_checkpoint_manager(rollouter) + + def set_total_train_steps(self, total_training_steps): + self.total_train_steps = total_training_steps + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + 
print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") - def set_total_train_steps(self, total_train_steps): - self.total_train_steps = total_train_steps self.progress_bar = tqdm(total=self.total_train_steps, initial=0, desc="Training Progress") def get_actor_wg(self): """Get actor worker group""" return self.actor_wg - def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]: + async def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]: """ Get samples from message queue and compose gen_batch_output Uses a loop to continuously collect samples until enough are gathered @@ -233,7 +269,7 @@ def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]: print( f"[FullyAsyncTrainer] Loop collection completed: {len(queue_samples)}/{self.required_samples} samples, " - f"total wait time: {total_wait_time:.2f} seconds." + f"total wait time: {total_wait_time:.2f} seconds. " f"mq_len: {queue_len}" ) @@ -281,13 +317,18 @@ async def init_workers(self): 1. Ray resource pools from configuration 2. Worker groups for each role (actor, critic, etc.) 
""" - # self._init_async_objects() self._init_resource_pools() self._create_worker_classes() self._init_worker_groups() self._init_models() + self._init_reward_loop() await self._init_async_rollout_manager() + def _init_reward_loop(self): + if self.config.async_training.use_trainer_do_validate: + print("[FullyAsyncTrainer] Init reward loop") + super()._init_reward_loop() + async def _init_async_rollout_manager(self): # use async rollout do validate print(f"[FullyAsyncTrainer] use_trainer_do_validate: {self.config.async_training.use_trainer_do_validate}") @@ -307,17 +348,41 @@ async def _init_async_rollout_manager(self): # create async rollout manager and request scheduler assert self.config.actor_rollout_ref.rollout.mode == "async" - from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager self.async_rollout_mode = True - self.async_rollout_manager = await FullyAsyncAgentLoopManager.create( + from verl.experimental.agent_loop import AgentLoopManager + + self.async_rollout_manager = await AgentLoopManager.create( config=self.config, worker_group=self.actor_rollout_wg, reward_loop_worker_handles=reward_loop_worker_handles, ) + print("[FullyAsyncTrainer] async_rollout_manager initialized") + + # Modify checkpoint_engine config to use naive backend + checkpoint_engine_cfg = self.config.actor_rollout_ref.rollout.checkpoint_engine + original_backend = checkpoint_engine_cfg.backend + with open_dict(checkpoint_engine_cfg): + checkpoint_engine_cfg.backend = "naive" + checkpoint_engine_config = omega_conf_to_dataclass(checkpoint_engine_cfg) + + print(f"[FullyAsyncTrainer] checkpoint_engine_config: {checkpoint_engine_config}") + + self.colocate_checkpoint_manager = CheckpointEngineManager( + config=checkpoint_engine_config, + trainer=self.actor_rollout_wg, + replicas=self.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + await self.colocate_checkpoint_manager.sleep_replicas() + + # Restore original backend 
value + with open_dict(checkpoint_engine_cfg): + checkpoint_engine_cfg.backend = original_backend + + print("[FullyAsyncTrainer] colocate_checkpoint_manager initialized") - print("[FullyAsyncTrainer] async_rollout_manager sleep") - await self.async_rollout_manager.sleep() else: print("[FullyAsyncTrainer] Skip async rollout manager (use_trainer_do_validate=False)") @@ -331,24 +396,12 @@ async def fit(self): print("[FullyAsyncTrainer] Starting FullyAsyncTrainer...") if self.message_queue_client is None: raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.") - if self.param_synchronizer is None: - raise ValueError("param_synchronizer client not set. Call set_parameter_synchronizer() first.") - - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - self.logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) + if self.rollouter is None: + raise ValueError("rollouter not set. Call set_rollouter() first.") self.max_steps_duration = 0 - # get validate data before training - self._log_validation_data() + self.global_steps += 1 # Use queue mode, no need for traditional dataloader iterator # Initialize to get the first batch of data @@ -359,17 +412,11 @@ async def fit(self): print("[FullyAsyncTrainer] Training stopped by queue termination signal") break - # final parameter sync and validate - # 1. waiting remaining validate task - ray.get(self.param_synchronizer.wait_last_valid.remote()) - self._log_validation_data() - # 2. 
perform addtional parameter_sync and validate if trainer already updated - if self.current_param_version % self.config.rollout.test_freq != 0 or self.local_trigger_step > 1: - await self._trigger_parameter_sync_after_step(validate=True, global_steps=self.global_steps) - ray.get(self.param_synchronizer.wait_last_valid.remote()) - self._log_validation_data() self.progress_bar.close() - self._fit_save_checkpoint() + if self.current_param_version % self.config.trainer.test_freq != 0 or self.local_trigger_step > 1: + await self._fit_update_weights() + await self._fit_validate() + self._fit_save_checkpoint(force=True) async def fit_step(self, batch_dict: dict = None): """ @@ -383,7 +430,6 @@ async def fit_step(self, batch_dict: dict = None): Args: batch_dict: Raw data dictionary """ - print("[FullyAsyncTrainer] fit_step") self.metrics = {"training/global_step": self.global_steps, "training/epoch": self.epoch} self.timing_raw = {} # reward message @@ -391,11 +437,10 @@ async def fit_step(self, batch_dict: dict = None): self.reward_tensor = None self.reward_extra_infos_dict = {} - # self._fit_prepare_step() self._fit_start_profile() with marked_timer("step", self.timing_raw): - batch = self._fit_generate(None) + batch = await self._fit_generate(None) batch = self._fit_compute_reward(batch) batch = self._fit_compute_log_prob(batch) batch = self._fit_compute_ref_log_prob(batch) @@ -403,22 +448,22 @@ async def fit_step(self, batch_dict: dict = None): batch = self._fit_compute_advantage(batch) batch = self._fit_update_critic(batch) batch = self._fit_update_actor(batch) + self._fit_update_local_step() await self._fit_update_weights() self._fit_dump_data(batch) - # self._fit_validate() + await self._fit_validate() self._fit_save_checkpoint() self._fit_stop_profile() self._fit_collect_metrics(batch) self._fit_torch_memory() - # self._fit_experimental(batch) self._fit_postprocess_step() - def _fit_generate(self, batch: DataProto = None) -> DataProto: + async def 
_fit_generate(self, batch: DataProto = None) -> DataProto | None: metrics = self.metrics timing_raw = self.timing_raw with marked_timer("gen", timing_raw, color="red"): - epoch, batch = self._get_samples_from_queue() + epoch, batch = await self._get_samples_from_queue() if batch is None: raise TrainingStopException("Training terminated: queue returned None") self._collect_metrics_from_samples(batch, metrics) @@ -447,18 +492,7 @@ def _compute_old_log_prob(self, batch: DataProto): self.actor_rollout_wg.clear_cpu_model(self.local_trigger_step) return old_log_prob, old_log_prob_mfu - def _fit_collect_metrics(self, batch): - super()._fit_collect_metrics(batch) - self.metrics_aggregator.add_step_metrics( - metrics=self.metrics, sample_count=self.required_samples, timestamp=time.time() - ) - self._log_validation_data() - - async def _fit_update_weights(self): - # with marked_timer("update_weights", self.timing_raw, color="red"): - # self.checkpoint_manager.update_weights() - - # Trigger parameter synchronization after training step + def _fit_update_local_step(self): time_str = datetime.now().strftime("%H:%M:%S.%f")[:-3] print( f"[FullyAsyncTrainer] global_steps: {self.global_steps} " @@ -466,9 +500,106 @@ async def _fit_update_weights(self): f"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} " f"{time_str}" ) - await self._trigger_parameter_sync_after_step() + if self.local_trigger_step < self.trigger_parameter_sync_step: + self.local_trigger_step += 1 + else: + self.current_param_version += 1 + self.local_trigger_step = 1 + + async def _fit_update_weights(self): + if self.local_trigger_step != 1: + return + + with marked_timer("timing_s/param_sync", self.timing_raw): + await self.checkpoint_manager.update_weights(global_steps=self.current_param_version) + print( + f"[FullyAsyncTrainer] _fit_update_weights, " + f"timing_s/param_sync: {self.timing_raw['timing_s/param_sync']:.4f} seconds " + f"self.current_param_version: {self.current_param_version}" + ) + + 
# Reset staleness in rollouter + timing_raw = ray.get(self.rollouter.reset_staleness.remote()) + self.logger.log( + data=timing_raw, + step=self.current_param_version, + ) + + # Log aggregated training metrics + self.logger.log( + data=self.metrics_aggregator.get_aggregated_metrics(), + step=self.current_param_version, + ) + self.metrics_aggregator.reset() + + async def _validate_process(self): + """Run trainer-side validation using async rollout manager""" + if self.config.async_training.use_trainer_do_validate: + print("[FullyAsyncTrainer] _validate_process") + from verl.utils.profiler import marked_timer + + # Wake up rollouter replicas and sync weights + print("[FullyAsyncTrainer] wake up replicas before validation") + await self.colocate_checkpoint_manager.update_weights(global_steps=self.current_param_version) + + with marked_timer("trainer/validate_time", self.timing_raw): + train_val_metrics = self._validate(True) + + # Sleep rollouter replicas to free GPU memory for validation + print("[FullyAsyncTrainer] sleep replicas after validation") + await self.colocate_checkpoint_manager.sleep_replicas() + + print(f"[FullyAsyncTrainer] validate timing: {self.timing_raw['trainer/validate_time']}") + return train_val_metrics + else: + print("[FullyAsyncTrainer] _validate_process without async_rollout_manager") + return None + + async def _fit_validate(self, val_before_train=False): + if self.local_trigger_step != 1: + return + + # Check if validation is needed + need_validate = ( + self.config.trainer.test_freq > 0 + and self.current_param_version % self.config.trainer.test_freq == 0 + and self.current_param_version > 0 + ) + # Skip validation if not needed and not validation before training + if not need_validate and not val_before_train: + return + + # Trigger rollouter validation and get future + val_future = self.rollouter.do_validate.remote() + + # Run trainer-side validation + train_val_metrics = await self._validate_process() + + # Wait for rollouter 
validation result and log + val_metrics: ValidateMetrics = ray.get(val_future) + if train_val_metrics: + # Merge trainer and rollouter validation results + with marked_timer("timing_s/merge_val", self.timing_raw): + new_metrics = self._merge_validation_results(train_val_metrics, val_metrics.metrics) + if new_metrics: + self.logger.log(data=new_metrics, step=self.current_param_version) + pprint( + f"[FullyAsyncTrainer] parameter version: {self.current_param_version} " + f"Validation metrics: {new_metrics}, timing: {self.timing_raw['timing_s/merge_val']}" + ) + else: + if val_metrics.metrics: + self.logger.log(data=val_metrics.metrics, step=self.current_param_version) + pprint( + f"[FullyAsyncTrainer] parameter version: {self.current_param_version} " + f"Validation metrics: {val_metrics.metrics}" + ) + self.logger.log(data=val_metrics.timing_raw, step=self.current_param_version) + + def _fit_save_checkpoint(self, force=False): + if self.current_param_version == self.last_ckpt_version: + return - def _fit_save_checkpoint(self): timing_raw = self.timing_raw # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. esi_close_to_expiration = should_save_ckpt_esi( @@ -483,20 +614,25 @@ def _fit_save_checkpoint(self): # 3. The current step number is a multiple of the save frequency. # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. 
if self.config.trainer.save_freq > 0 and ( - self.current_param_version % self.config.trainer.save_freq == 0 or esi_close_to_expiration + force and self.current_param_version % self.config.trainer.save_freq == 0 or esi_close_to_expiration ): if esi_close_to_expiration: print("Force saving checkpoint: ESI instance expiration approaching.") with marked_timer("save_checkpoint", timing_raw, color="green"): # sleep replicas to avoid OOM during checkpoint saving - # self.checkpoint_manager.sleep_replicas() self._save_checkpoint() - # wake replicas to avoid OOM during checkpoint saving - # self.checkpoint_manager.update_weights() + self.last_ckpt_version = self.current_param_version def _fit_postprocess_step(self): self.global_steps += 1 + self.metrics_aggregator.add_step_metrics( + metrics=self.metrics, sample_count=self.required_samples, timestamp=time.time() + ) + + if self.local_trigger_step == 1: + self.progress_bar.update(1) + def _save_checkpoint(self): # Warning: Currently, to align the training process and metrics of colocate, # we use current_param_version instead of global step. 
@@ -552,7 +688,7 @@ def _save_checkpoint(self): self.current_param_version, max_ckpt_to_keep=max_critic_ckpt_to_keep, ) - ray.get(self.param_synchronizer.rollouter_save_checkpoint.remote(local_global_step_folder)) + ray.get(self.rollouter.save_checkpoint.remote(local_global_step_folder)) # latest checkpointed iteration tracker (for atomic usage) local_latest_checkpointed_iteration = os.path.join( self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" @@ -560,7 +696,7 @@ def _save_checkpoint(self): with open(local_latest_checkpointed_iteration, "w") as f: f.write(str(self.current_param_version)) - def load_checkpoint(self): + async def load_checkpoint(self): if self.config.trainer.resume_mode == "disable": return 0 @@ -610,6 +746,11 @@ def load_checkpoint(self): self.critic_wg.load_checkpoint( critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load ) + + if self.colocate_checkpoint_manager: + await self.colocate_checkpoint_manager.update_weights(self.current_param_version) + await self.colocate_checkpoint_manager.sleep_replicas() + return self.current_param_version def _collect_metrics_from_samples(self, batch, metrics): @@ -617,15 +758,11 @@ def _collect_metrics_from_samples(self, batch, metrics): Collect metrics from samples """ if hasattr(batch, "meta_info") and batch.meta_info: - samples_param_versions = batch.meta_info["rollout_param_versions"] - stale_count = sum(1 for v in samples_param_versions if self.current_param_version - v >= 1) - self.stale_samples_processed += stale_count trajectory_param_versions = batch.meta_info["trajectory_param_versions"] stale_traj_count = sum(1 for v in trajectory_param_versions if self.current_param_version - v >= 1) self.stale_trajectory_processed += stale_traj_count metrics.update( { - "fully_async/count/stale_samples_processed": self.stale_samples_processed, "fully_async/count/stale_trajectory_processed": self.stale_trajectory_processed, "fully_async/count/current_param_version": 
self.current_param_version, } @@ -633,92 +770,3 @@ def _collect_metrics_from_samples(self, batch, metrics): for key, value in batch.meta_info.items(): if key.startswith("fully_async") or key.startswith("timing_s"): metrics[key] = value - - async def _trigger_parameter_sync_after_step(self, validate: bool = False): - """ - Trigger parameter synchronization after training step - This ensures rollouter always uses the latest trained parameters - """ - if self.local_trigger_step < self.trigger_parameter_sync_step and not validate: - self.local_trigger_step += 1 - return - - self.current_param_version += 1 - self.local_trigger_step = 1 - self.logger.log( - data=self.metrics_aggregator.get_aggregated_metrics(), - step=self.current_param_version, - ) - self.progress_bar.update(1) - self.metrics_aggregator.reset() - timing_param_sync = {} - with marked_timer("timing_s/wait_last_valid", timing_param_sync): - ray.get(self.param_synchronizer.wait_last_valid.remote()) - with marked_timer("timing_s/param_sync", timing_param_sync): - ray.get( - self.param_synchronizer.sync_weights.remote( - self.current_param_version, - validate=validate, - global_steps=self.global_steps, - use_trainer_do_validate=self.config.async_training.use_trainer_do_validate, - ) - ) - - # do trainer validate - do_validate_param = ( - self.config.rollout.test_freq > 0 - and self.current_param_version % self.config.rollout.test_freq == 0 - and self.current_param_version > 0 - ) - print(f"do_validate_param: {do_validate_param}") - if do_validate_param and self.config.async_training.use_trainer_do_validate: - print(f"[FullyAsyncTrainer] validate param version: {self.current_param_version}") - await self._validate_process() - else: - self.train_val_metrics = None - self.logger.log(data=timing_param_sync, step=self.current_param_version) - - def _log_validation_data(self): - """ - Log validation data - """ - val_data = self.message_queue_client.get_validate_sync() - if not val_data: - return - - val_metrics: 
ValidateMetrics = ray.cloudpickle.loads(val_data) - if self.train_val_metrics and self.config.async_training.use_trainer_do_validate: - # merge info - timing_param_sync = {} - with marked_timer("timing_s/merge_val", timing_param_sync): - new_metrics = self._merge_validation_results(self.train_val_metrics, val_metrics.metrics) - if new_metrics: - self.logger.log(data=new_metrics, step=val_metrics.param_version) - pprint( - f"[FullyAsyncTrainer] parameter version: {val_metrics.param_version} " - f"Validation metrics: {new_metrics}, timing_param_sync: {timing_param_sync['timing_s/merge_val']}" - ) - self.logger.log(data=val_metrics.timing_raw, step=val_metrics.param_version) - else: - if val_metrics.metrics: - self.logger.log(data=val_metrics.metrics, step=val_metrics.param_version) - pprint( - f"[FullyAsyncTrainer] parameter version: {val_metrics.param_version} " - f"Validation metrics: {val_metrics.metrics}" - ) - self.logger.log(data=val_metrics.timing_raw, step=val_metrics.param_version) - - async def _validate_process(self): - if self.config.async_training.use_trainer_do_validate: - print("[FullyAsyncTrainer] _validate_process") - from verl.utils.profiler import marked_timer - - timing_raw = {} - await self.async_rollout_manager.wake_up() - with marked_timer("trainer/validate_time", timing_raw): - self.train_val_metrics = self._validate(True) - await self.async_rollout_manager.sleep() - print(f"[FullyAsyncTrainer] validate timing_raw validate: {timing_raw['trainer/validate_time']}") - else: - self.train_val_metrics = None - print("[FullyAsyncTrainer] _validate_process without async_rollout_manager") diff --git a/verl/experimental/fully_async_policy/megatron_utils.py b/verl/experimental/fully_async_policy/megatron_utils.py deleted file mode 100644 index 9f5380f25c5..00000000000 --- a/verl/experimental/fully_async_policy/megatron_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. 
and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from megatron.core.distributed import DistributedDataParallel as DDP - - -@torch.no_grad() -def copy_megatron_model_to_cpu(models): - """ - Copy Megatron model parameters to CPU memory (non-destructive copy). - Unlike offload_megatron_model_to_cpu which moves data, this function creates - independent copies on CPU while keeping GPU data intact. - - Args: - models: List of model chunks (DDP-wrapped or unwrapped) - - Returns: - dict: CPU state containing copied parameters and buffers - """ - cpu_state = {} - - for model_idx, model_chunk in enumerate(models): - if isinstance(model_chunk, DDP): - # Handle DDP-wrapped models - model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] - buffer_states = [] - - for buffers in model_chunk_all_buffers: - buffer_list = [] - for buffer in buffers: - buffer_state = {} - - # Copy parameter data to CPU - if buffer.param_data.storage().size() > 0: - buffer_state["param_data"] = buffer.param_data.data.cpu().clone().pin_memory() - - buffer_list.append(buffer_state) - buffer_states.append(buffer_list) - - cpu_state[f"model_chunk_{model_idx}"] = {"buffer_states": buffer_states, "is_ddp": True} - else: - # Handle non-DDP models (ref module) - model_state = {} - for name, param in model_chunk.named_parameters(): - param_state = {"data": param.data.cpu().clone().pin_memory()} - model_state[name] = param_state - - 
cpu_state[f"model_chunk_{model_idx}"] = {"model_state": model_state, "is_ddp": False} - - return cpu_state - - -@torch.no_grad() -def restore_megatron_model_from_cpu(models, cpu_state): - """ - Restore Megatron model parameters from CPU memory back to GPU. - - Args: - models: List of model chunks to restore to - cpu_state: CPU state dict returned from copy_megatron_model_to_cpu - """ - for model_idx, model_chunk in enumerate(models): - chunk_key = f"model_chunk_{model_idx}" - if chunk_key not in cpu_state: - continue - - chunk_state = cpu_state[chunk_key] - - if chunk_state["is_ddp"] and isinstance(model_chunk, DDP): - # Restore DDP buffers - model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] - buffer_states = chunk_state["buffer_states"] - - for buffers, buffer_list in zip(model_chunk_all_buffers, buffer_states, strict=False): - for buffer, buffer_state in zip(buffers, buffer_list, strict=False): - # Restore parameter data - if "param_data" in buffer_state: - buffer.param_data.data.copy_(buffer_state["param_data"].to(buffer.param_data.device)) - - elif not chunk_state["is_ddp"] and not isinstance(model_chunk, DDP): - # Restore non-DDP models - model_state = chunk_state["model_state"] - for name, param in model_chunk.named_parameters(): - if name in model_state: - param_state = model_state[name] - param.data.copy_(param_state["data"].to(param.device)) diff --git a/verl/experimental/fully_async_policy/message_queue.py b/verl/experimental/fully_async_policy/message_queue.py index f5dcec566bc..8d419d36449 100644 --- a/verl/experimental/fully_async_policy/message_queue.py +++ b/verl/experimental/fully_async_policy/message_queue.py @@ -35,18 +35,9 @@ def __init__(self, config: DictConfig, max_queue_size: int = 1000): raise ValueError(f"max_queue_size cannot be None, got: {max_queue_size}") self.max_queue_size = int(max_queue_size) self.queue = deque(maxlen=self.max_queue_size) - self.current_param_version = 0 self.val_queue = deque() - 
try: - if hasattr(config, "async_training") and config.async_training is not None: - self.staleness_threshold = getattr(config.async_training, "staleness_threshold", 3) - else: - self.staleness_threshold = 3 - except (AttributeError, RecursionError): - self.staleness_threshold = 3 - # Asyncio for message handling self.running = True @@ -59,18 +50,14 @@ def __init__(self, config: DictConfig, max_queue_size: int = 1000): self.total_consumed = 0 self.dropped_samples = 0 - print( - f"[MessageQueue] initialized with max_queue_size={max_queue_size}, " - f"staleness_threshold={self.staleness_threshold}" - ) + print(f"[MessageQueue] initialized with max_queue_size={max_queue_size}") - async def put_sample(self, sample: Any, param_version: int) -> bool: + async def put_sample(self, sample: Any) -> bool: """ Put a batch sample into the queue Args: sample: Sample data - param_version: Parameter version number Returns: bool: Whether the sample was successfully put into the queue @@ -115,13 +102,6 @@ async def get_sample(self) -> Any | None: self.total_consumed += 1 return data, len(self.queue) - async def update_param_version(self, version: int): - """Update current parameter version""" - async with self._lock: - old_version = self.current_param_version - self.current_param_version = version - print(f"Parameter version updated from {old_version} to {version}") - async def get_queue_size(self) -> int: """Get current queue length""" async with self._lock: @@ -135,8 +115,6 @@ async def get_statistics(self) -> dict[str, Any]: "total_produced": self.total_produced, "total_consumed": self.total_consumed, "dropped_samples": self.dropped_samples, - "current_param_version": self.current_param_version, - "staleness_threshold": self.staleness_threshold, "max_queue_size": self.max_queue_size, } @@ -205,9 +183,9 @@ class MessageQueueClient: def __init__(self, queue_actor: Any): self.queue_actor = queue_actor - async def put_sample(self, sample: Any, param_version: int) -> bool: + async def 
put_sample(self, sample: Any) -> bool: """Put batch into queue (async)""" - future = self.queue_actor.put_sample.remote(sample, param_version) + future = self.queue_actor.put_sample.remote(sample) return await asyncio.wrap_future(future.future()) async def put_validate(self, data: Any) -> bool: @@ -247,11 +225,6 @@ async def get_memory_usage(self) -> dict: future = self.queue_actor.get_memory_usage.remote() return await asyncio.wrap_future(future.future()) - # Synchronous version of the method (deprecated) - def put_sample_sync(self, sample: Any, param_version: int) -> bool: - """Put batch into queue (sync - deprecated, use put_sample instead)""" - return ray.get(self.queue_actor.put_sample.remote(sample, param_version)) - def get_sample_sync(self) -> Any | None: """Get single sample from queue (sync - deprecated, use get_sample instead)""" return ray.get(self.queue_actor.get_sample.remote()) @@ -259,7 +232,3 @@ def get_sample_sync(self) -> Any | None: def get_statistics_sync(self) -> dict[str, Any]: """Get statistics (sync - deprecated, use get_statistics instead)""" return ray.get(self.queue_actor.get_statistics.remote()) - - def update_param_version_sync(self, version: int): - """Update parameter version (async)""" - return ray.get(self.queue_actor.update_param_version.remote(version)) diff --git a/verl/experimental/fully_async_policy/param_sync.py b/verl/experimental/fully_async_policy/param_sync.py deleted file mode 100644 index 9ec24feaf60..00000000000 --- a/verl/experimental/fully_async_policy/param_sync.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time - -import ray -from ray.util.collective import collective - -from verl.checkpoint_engine import CheckpointEngineManager -from verl.utils.config import omega_conf_to_dataclass -from verl.utils.device import get_nccl_backend - -logger = logging.getLogger(__name__) - - -@ray.remote -class ParameterSynchronizer: - """ - Unified parameter synchronizer, responsible for synchronizing model parameters between actor and rollout - Based on the mature synchronization mode implementation of one_step_off_policy - Merges the functions of the original multiple synchronizer classes - """ - - def __init__(self, config, trainer, rollouter, mq): - self.config = config - self.trainer = trainer - self.rollouter = rollouter - self.mq_client = mq - self.actor_wg = ray.get(trainer.get_actor_wg.remote()) - self.rollout_wg = ray.get(rollouter.get_rollout_wg.remote()) - - # Basic attributes - self.weights_info = None - self.sync_group_initialized = False - self.sync_group_name = "actor_rollout" - self.wait_last_update = None - self.wait_last_resume = None - self.validate_task = None - - # Statistics - self.current_version = 0 - - replicas = ray.get(rollouter.get_replicas.remote()) - checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine) - self.checkpoint_manager = CheckpointEngineManager( - config=checkpoint_engine_config, trainer=self.actor_wg, replicas=replicas - ) - - def get_current_param_version(self) -> int: - """Get current parameter version number""" - return self.current_version - - def 
get_weights_info(self): - """Get weights info""" - return self.weights_info - - def _init_weights_info(self): - self.weights_info = self.actor_wg.get_actor_weights_info()[0] - self.rollout_wg.set_actor_weights_info(self.weights_info) - - def _init_sync_group(self): - print("[ParameterSynchronizer] Initializing parameter synchronization group...") - actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers - n_workers = len(self.actor_wg.workers + self.rollout_wg.workers) - if self.config.trainer.device == "npu": - master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote()).strip("[]") - master_port = ray.get(self.actor_wg.workers[0]._get_free_port.remote()) - self.actor_wg.create_weight_sync_group( - master_address, - master_port, - 0, - n_workers, - ) - ray.get( - self.rollout_wg.create_weight_sync_group( - master_address, - master_port, - len(self.actor_wg.workers), - n_workers, - ) - ) - else: - collective.create_collective_group( - actor_rollout_workers, - n_workers, - list(range(0, n_workers)), - backend=get_nccl_backend(), - group_name=self.sync_group_name, - ) - - def sync_weights(self, version, validate=False, global_steps=0, use_trainer_do_validate=False): - """Sync weights between trainer and rollouter, and update parameter version""" - start_time = time.time() - - self.current_version = version - ray.get(self.rollouter.pause.remote()) - - print(f"[ParameterSynchronizer] rollout paused. cost {time.time() - start_time:.2f} seconds") - # Update MQ version - self.mq_client.update_param_version_sync(version) - - pause_time = time.time() - - # sync weights - # For sglang, always use sync_rollout_weights instead of sync_rollout_weights_by_checkpoint - - self.checkpoint_manager.update_weights(global_steps) - end_time = time.time() - print( - f"[ParameterSynchronizer] sync_weights success. 
cost {end_time - start_time:.2f} seconds, " - f"pause:{pause_time - start_time:.2f}s, sync:{end_time - pause_time:.2f}s" - ) - # async train do validate - print(f"[ParameterSynchronizer] validate: {validate}, use_trainer_do_validate: {use_trainer_do_validate}") - if validate and use_trainer_do_validate: - print("[ParameterSynchronizer] use trainer to do validate") - self.validate_task = self.trainer._validate_process.remote() - else: - self.validate_task = None - # Async Update rollout version & validation - self.wait_last_update = self.rollouter.update_param_version.remote( - version, validate, global_steps, use_trainer_do_validate - ) - self.wait_last_resume = self.rollouter.resume.remote(self.wait_last_update) - - def wait_last_valid(self): - print("[ParameterSynchronizer] Waiting last sync and validate...") - start_time = time.time() - if self.wait_last_update: - ray.get(self.wait_last_update) - if self.wait_last_resume: - ray.get(self.wait_last_resume) - if self.validate_task: - ray.get(self.validate_task) - print(f"[ParameterSynchronizer] Wait last validate cost: {time.time() - start_time:.2f} seconds") - - def rollouter_save_checkpoint(self, local_global_step_folder: str): - """Trigger rollout to save checkpoint(dataloader)""" - print(f"[ParameterSynchronizer] Triggering checkpoint save at {local_global_step_folder} ...") - return ray.get(self.rollouter.save_checkpoint.remote(local_global_step_folder)) diff --git a/verl/experimental/fully_async_policy/sglang_rollout/__init__.py b/verl/experimental/fully_async_policy/sglang_rollout/__init__.py deleted file mode 100644 index 9cd3ed5b8e9..00000000000 --- a/verl/experimental/fully_async_policy/sglang_rollout/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py b/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py deleted file mode 100644 index 89097880f4d..00000000000 --- a/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import asyncio -import logging -from typing import Any, Optional - -import ray -import torch -from ray.actor import ActorHandle - -from verl.workers.config import HFModelConfig, RolloutConfig -from verl.workers.rollout.replica import RolloutMode -from verl.workers.rollout.sglang_rollout.async_sglang_server import ( - SGLangHttpServer, - SGLangReplica, -) - -logger = logging.getLogger(__file__) -logger.setLevel(logging.INFO) - - -class SGLangHttpServerForPartial(SGLangHttpServer): - def __init__( - self, - config: RolloutConfig, - model_config: HFModelConfig, - rollout_mode: RolloutMode, - workers: list[ActorHandle], - replica_rank: int, - node_rank: int, - nnodes: int, - cuda_visible_devices: str, - base_gpu_id: int, - ): - super().__init__( - config=config, - model_config=model_config, - rollout_mode=rollout_mode, - workers=workers, - replica_rank=replica_rank, - node_rank=node_rank, - nnodes=nnodes, - cuda_visible_devices=cuda_visible_devices, - base_gpu_id=base_gpu_id, - ) - - # for cancel LLMServer - self.paused = False - self.lock = asyncio.Lock() - self.cancel_event: dict[str, asyncio.Event] = {} - self.req_output: dict[str, Optional[dict[str, Any]]] = {} - - async def _generate_step( - self, - prompt_ids: torch.Tensor, - sampling_params: dict[str, Any], - request_id: str, - image_data: Optional[list[Any]] = None, - video_data: Optional[list[Any]] = None, - ) -> None: - sampling_params = dict(sampling_params) - - max_new_tokens = min( - self.config.response_length, - self.config.max_model_len - len(prompt_ids) - 1, - ) - sampling_params["max_new_tokens"] = max_new_tokens - - sampling_params.setdefault( - "repetition_penalty", - self.config.get("repetition_penalty", 1.0), - ) - - sampling_params.pop("logprobs", None) - return_logprob = True - from sglang.srt.managers.io_struct import GenerateReqInput - - if video_data is not None and len(video_data) > 0: - logger.warning( - f"Request {request_id} received video_data but it is not used. 
" - "This is to keep consistency with the implementation in " - "verl/workers/rollout/sglang_rollout/async_sglang_server.py. " - "Video data will be ignored." - ) - - request = GenerateReqInput( - rid=request_id, - input_ids=prompt_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - image_data=image_data, - # TODO: support video input for sglang - # video_data=video_data, - ) - - generator = self.tokenizer_manager.generate_request(request, None) - async for output in generator: - self.req_output[request_id] = output - - assert self.req_output[request_id] is not None - - async def generate_for_partial( - self, - prompt_ids: torch.Tensor, - sampling_params: dict[str, Any], - request_id: str, - image_data: Optional[list[Any]] = None, - video_data: Optional[list[Any]] = None, - ) -> tuple[list[int], list[float], bool]: - async with self.lock: - if self.paused: - return [], [], True - self.req_output[request_id] = None - self.cancel_event[request_id] = asyncio.Event() - cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait()) - generation_handle = asyncio.create_task( - self._generate_step(prompt_ids, sampling_params, request_id, image_data, video_data) - ) - done, pending = await asyncio.wait( - [generation_handle, cancel_handle], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in done: - await task - - for task in pending: - task.cancel() - async with self.lock: - output = self.req_output.get(request_id) - if output is None: - self.cancel_event.pop(request_id, None) - self.req_output.pop(request_id, None) - return [], [], True - meta_info = output.get("meta_info", {}) - output_token_logprobs = meta_info.get("output_token_logprobs") - - token_ids: list[int] = [] - log_probs: list[float] = [] - - if output_token_logprobs is not None: - for log_prob, token_id, _ in output_token_logprobs: - token_ids.append(int(token_id)) - log_probs.append(float(log_prob)) - else: - token_ids = list(output["output_ids"]) - log_probs = [] - 
is_cancel = generation_handle not in done - self.cancel_event.pop(request_id, None) - self.req_output.pop(request_id, None) - - return token_ids, log_probs, is_cancel - - async def cancel(self): - async with self.lock: - self.paused = True - for request_id in self.cancel_event: - self.cancel_event[request_id].set() - - async def resume(self): - async with self.lock: - self.paused = False - - -class FullyAsyncSGLangReplica(SGLangReplica): - def __init__( - self, - replica_rank: int, - config: RolloutConfig, - model_config: HFModelConfig, - gpus_per_node: int = 8, - is_reward_model: bool = False, - ): - super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model) - self.server_class = ray.remote(SGLangHttpServerForPartial) - - async def cancel(self): - """Cancel each rollout server.""" - await asyncio.gather(*[server.cancel.remote() for server in self.servers]) - - async def resume(self): - """Resume each rollout server.""" - await asyncio.gather(*[server.resume.remote() for server in self.servers]) diff --git a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh index cc936f50dc1..e364a6e591e 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh @@ -179,8 +179,8 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ rollout.nnodes="${n_nodes_rollout}" \ rollout.n_gpus_per_node="${n_gpus_rollout}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.test_freq=${test_freq} \ - rollout.total_epochs=10 \ + trainer.total_epochs=10 \ + trainer.test_freq=${test_freq} \ async_training.require_batches=${require_batches} \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ diff --git 
a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh index 2a5eb1bb966..7d7e56bdf6a 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh @@ -132,8 +132,8 @@ python3 -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes=$NNODES \ rollout.n_gpus_per_node=$n_gpus_rollout \ rollout.total_rollout_steps=$total_rollout_steps \ - rollout.total_epochs=10 \ - rollout.test_freq=$test_freq \ + trainer.total_epochs=10 \ + trainer.test_freq=$test_freq \ async_training.staleness_threshold=$staleness_threshold \ async_training.trigger_parameter_sync_step=$trigger_parameter_sync_step \ async_training.require_batches=$require_batches \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh index ba8e6804fdb..b90e7eb2985 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh @@ -153,8 +153,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh index 5561208ee6d..922c90f31ae 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh 
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh @@ -153,8 +153,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh index 242a5117a5e..d1b9d821e7e 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh @@ -156,7 +156,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES}" \ rollout.n_gpus_per_node="${n_gpus_rollout}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh index ee0657eace7..8bf73be8c85 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh @@ -146,7 +146,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ trainer.logger=['console','tensorboard'] \ trainer.project_name="${project_name}" \ 
trainer.experiment_name="${exp_name}" \ - trainer.val_before_train=False \ + trainer.val_before_train=True \ trainer.save_freq=-1 \ trainer.default_local_dir="${CKPTS_DIR}" \ trainer.resume_mode=auto \ @@ -155,8 +155,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES}" \ rollout.n_gpus_per_node="${n_gpus_rollout}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh index 002c1206b8a..b9856540d62 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh @@ -153,8 +153,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh index f01fb8184e7..91985e5f87e 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh +++ 
b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh @@ -159,8 +159,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh index 2b2143ffa21..aa916e37170 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh @@ -153,8 +153,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh b/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh index 8b32c6e0078..7d4cdffbe92 100644 --- a/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh +++ b/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh @@ -102,8 +102,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES}" 
\ rollout.n_gpus_per_node="${n_gpus_rollout}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs="${total_epochs}" \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs="${total_epochs}" \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh index ebcb634ff72..3b3dae7f4c1 100644 --- a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh +++ b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh @@ -220,8 +220,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh index c04a09d3266..318d88cb2d7 100644 --- a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh +++ b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh @@ -229,8 +229,8 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ rollout.nnodes="${NNODES_ROLLOUT}" \ rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ 
rollout.total_rollout_steps="${total_rollout_steps}" \ - rollout.total_epochs=10 \ - rollout.test_freq="${test_freq}" \ + trainer.total_epochs=10 \ + trainer.test_freq="${test_freq}" \ async_training.staleness_threshold="${staleness_threshold}" \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ diff --git a/verl/experimental/fully_async_policy/vllm_rollout/__init__.py b/verl/experimental/fully_async_policy/vllm_rollout/__init__.py deleted file mode 100644 index 9cd3ed5b8e9..00000000000 --- a/verl/experimental/fully_async_policy/vllm_rollout/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py deleted file mode 100644 index d45dbec890f..00000000000 --- a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import logging -from typing import Any, Optional, Sequence - -import ray -from ray.actor import ActorHandle -from vllm import SamplingParams -from vllm.inputs import TokensPrompt -from vllm.outputs import RequestOutput - -from verl.utils.tokenizer import normalize_token_ids -from verl.workers.config import HFModelConfig, RolloutConfig -from verl.workers.rollout.replica import RolloutMode -from verl.workers.rollout.vllm_rollout.vllm_async_server import ( - _qwen2_5_vl_dedup_image_tokens, - vLLMHttpServer, - vLLMReplica, -) - -logger = logging.getLogger(__file__) -logger.setLevel(logging.INFO) - - -class vLLMHttpServerForPartial(vLLMHttpServer): - def __init__( - self, - config: RolloutConfig, - model_config: HFModelConfig, - rollout_mode: RolloutMode, - workers: list[ActorHandle], - replica_rank: int, - node_rank: int, - gpus_per_node: int, - nnodes: int, - cuda_visible_devices: str, - ): - super().__init__( - config, - model_config, - rollout_mode, - workers, - replica_rank, - node_rank, - gpus_per_node, - nnodes, - cuda_visible_devices, - ) - - # for cancel LLMServer - self.paused = False - self.lock = asyncio.Lock() - self.cancel_event: dict[str, asyncio.Event] = {} - self.req_output: dict[str, Optional[RequestOutput]] = {} - - async def _generate_step( - self, - prompt_ids: list[int], - sampling_params: dict[str, Any], - request_id: str, - image_data: Optional[list[Any]] = None, - video_data: Optional[list[Any]] = None, - ): - prompt_ids = normalize_token_ids(prompt_ids) - max_tokens = self.config.max_model_len - len(prompt_ids) - 
sampling_params["logprobs"] = 1 - sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0)) - sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) - prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor) - multi_modal_data = {} - if image_data is not None: - multi_modal_data["image"] = image_data - if video_data is not None: - multi_modal_data["video"] = video_data - prompt = TokensPrompt(prompt_token_ids=prompt_ids, multi_modal_data=multi_modal_data) - generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id) - - # Get final response - async for output in generator: - self.req_output[request_id] = output - assert self.req_output[request_id] is not None - - async def generate_for_partial( - self, - prompt_ids: list[int], - sampling_params: dict[str, Any], - request_id: str, - image_data: Optional[list[Any]] = None, - video_data: Optional[list[Any]] = None, - ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]: - async with self.lock: - if self.paused: - # After cancel, all tasks will return directly and wait for the next submission - return [], [], True - self.req_output[request_id]: Optional[RequestOutput] = None - self.cancel_event[request_id] = asyncio.Event() - cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait()) - generation_handle = asyncio.create_task( - self._generate_step(prompt_ids, sampling_params, request_id, image_data, video_data) - ) - - done, pend = await asyncio.wait([generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED) - - for task in done: - await task - - for task in pend: - task.cancel() - - async with self.lock: - if self.req_output[request_id] is None: - return [], [], True - token_ids = self.req_output[request_id].outputs[0].token_ids - log_probs: list[float] = [] - for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs): - # In 
sampling_params, logprobs is set to 1, which should return 1, - # but in practice there are multiple. Take the log_prob corresponding to token_id - token_id = self.req_output[request_id].outputs[0].token_ids[i] - log_probs.append(x[token_id].logprob) - is_cancel = generation_handle not in done - self.cancel_event.pop(request_id, None) - self.req_output.pop(request_id, None) - return token_ids, log_probs, is_cancel - - async def cancel(self): - async with self.lock: - self.paused = True - for request_id in self.cancel_event: - self.cancel_event[request_id].set() - - async def resume(self): - async with self.lock: - self.paused = False - - -class FullyAsyncvLLMReplica(vLLMReplica): - def __init__( - self, - replica_rank: int, - config: RolloutConfig, - model_config: HFModelConfig, - gpus_per_node: int = 8, - is_reward_model: bool = False, - ): - super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model) - self.server_class = ray.remote(vLLMHttpServerForPartial) - - async def cancel(self): - """Cancel each rollout server.""" - await asyncio.gather(*[server.cancel.remote() for server in self.servers]) - - async def resume(self): - """Resume each rollout server.""" - await asyncio.gather(*[server.resume.remote() for server in self.servers]) diff --git a/verl/experimental/one_step_off_policy/agent_loop/__init__.py b/verl/experimental/one_step_off_policy/agent_loop/__init__.py deleted file mode 100644 index a9eb0705e41..00000000000 --- a/verl/experimental/one_step_off_policy/agent_loop/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .agent_loop import OneStepOffAgentLoopManager - -__all__ = [OneStepOffAgentLoopManager] diff --git a/verl/experimental/one_step_off_policy/agent_loop/agent_loop.py b/verl/experimental/one_step_off_policy/agent_loop/agent_loop.py deleted file mode 100644 index 85455d655b2..00000000000 --- a/verl/experimental/one_step_off_policy/agent_loop/agent_loop.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import logging -import os - -import ray - -from verl.experimental.agent_loop.agent_loop import AgentLoopManager -from verl.protocol import DataProto - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class OneStepOffAgentLoopManager(AgentLoopManager): - async def generate_sequences_async(self, prompts: DataProto) -> DataProto: - """Split input batch and dispatch to agent loop workers (async version). - - Args: - prompts (DataProto): Input batch. - - Returns: - DataProto: Output batch. 
- """ - - chunkes = prompts.chunk(len(self.agent_loop_workers)) - # Use asyncio.gather with ray.get wrapped in asyncio.to_thread to avoid blocking - import asyncio - - outputs = await asyncio.gather( - *[ - asyncio.to_thread(ray.get, worker.generate_sequences.remote(chunk)) - for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) - ] - ) - output = DataProto.concat(outputs) - - # calculate performance metrics - metrics = [output.meta_info.pop("metrics") for output in outputs] # List[List[Dict[str, str]]] - timing = self._performance_metrics(metrics, output) - - output.meta_info = {"timing": timing, **outputs[0].meta_info} - return output - - async def wake_up(self): - await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas]) - - async def sleep(self): - await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas]) - - async def clear_kv_cache(self): - await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas]) diff --git a/verl/experimental/one_step_off_policy/distributed_utils.py b/verl/experimental/one_step_off_policy/distributed_utils.py deleted file mode 100644 index d117fb96f14..00000000000 --- a/verl/experimental/one_step_off_policy/distributed_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import ipaddress -import socket -from datetime import timedelta - -import vllm -from torch.distributed import TCPStore -from vllm.distributed.utils import StatelessProcessGroup - -from verl.utils.device import is_npu_available - - -@staticmethod -def create( - host: str, - port: int, - rank: int, - world_size: int, - data_expiration_seconds: int = 3600, - store_timeout: int = 300, -) -> "StatelessProcessGroup": - """A replacement for `torch.distributed.init_process_group` that does not - pollute the global state. - - If we have process A and process B called `torch.distributed.init_process_group` - to form a group, and then we want to form another group with process A, B, C, - D, it is not possible in PyTorch, because process A and process B have already - formed a group, and process C and process D cannot join that group. This - function is a workaround for this issue. - - `torch.distributed.init_process_group` is a global call, while this function - is a stateless call. It will return a `StatelessProcessGroup` object that can be - used for exchanging metadata. With this function, process A and process B - can call `StatelessProcessGroup.create` to form a group, and then process A, B, - C, and D can call `StatelessProcessGroup.create` to form another group. - - Args: - host: Host address (IPv4 or IPv6). For IPv6, can be in format like "::1" or "[::1]". - port: Port number to bind/listen on. - rank: Rank of the current process. - world_size: Total number of processes in the group. - data_expiration_seconds: Time in seconds before data entries expire (default: 3600). - store_timeout: Timeout in seconds for TCPStore connection (default: 300). - - Returns: - StatelessProcessGroup: A stateless process group instance. 
- """ # noqa - # Detect address family (IPv4 or IPv6) - try: - # Try to parse as IPv6 first (IPv6 addresses are more specific) - ipaddress.IPv6Address(host.strip("[]")) - address_family = socket.AF_INET6 - except (ipaddress.AddressValueError, ValueError): - address_family = socket.AF_INET - - launch_server = rank == 0 - if launch_server: - # listen on the specified interface (instead of 0.0.0.0 or ::) - listen_socket = socket.socket(address_family, socket.SOCK_STREAM) - listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - - # For IPv6, set IPV6_V6ONLY to only listen on IPv6 (not dual-stack) - # This ensures consistent behavior across different systems - if address_family == socket.AF_INET6: - try: - listen_socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) - except (AttributeError, OSError): - # IPV6_V6ONLY might not be available on all systems - pass - - # Remove brackets from IPv6 address if present (socket.bind handles it) - bind_host = host.strip("[]") - listen_socket.bind((bind_host, port)) - listen_socket.listen() - listen_fd = listen_socket.fileno() - else: - listen_socket = None - listen_fd = None - - store = TCPStore( - host_name=host, - port=port, - world_size=world_size, - is_master=launch_server, - timeout=timedelta(seconds=store_timeout), - use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 - master_listen_fd=listen_fd, - ) - - return StatelessProcessGroup( - rank=rank, - world_size=world_size, - store=store, - socket=listen_socket, - data_expiration_seconds=data_expiration_seconds, - ) - - -vllm.distributed.utils.StatelessProcessGroup.create = create - - -def vllm_stateless_init_process_group(master_address, master_port, rank, world_size, device): - """ - vLLM provides `StatelessProcessGroup` to create a process group - without considering the global process group in torch.distributed. 
- It is recommended to create `StatelessProcessGroup`, and then initialize - the data-plane communication (NCCL) between external (train processes) - and vLLM workers. - """ - # NOTE: If it is necessary to support weight synchronization with the sglang backend in the future, - # the following can be used: - # from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator - # from sglang.srt.distributed.utils import statelessprocessgroup - if is_npu_available: - from vllm_ascend.distributed.device_communicators.pyhccl import ( - PyHcclCommunicator as PyNcclCommunicator, - ) - else: - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - - pg = StatelessProcessGroup.create(host=master_address, port=master_port, rank=rank, world_size=world_size) - pynccl = PyNcclCommunicator(pg, device=device) - return pynccl diff --git a/verl/experimental/one_step_off_policy/main_ppo.py b/verl/experimental/one_step_off_policy/main_ppo.py index 0c6ecaedf0e..1f3c5fa0a85 100644 --- a/verl/experimental/one_step_off_policy/main_ppo.py +++ b/verl/experimental/one_step_off_policy/main_ppo.py @@ -24,77 +24,13 @@ import ray from verl.experimental.one_step_off_policy.ray_trainer import OneStepOffRayTrainer -from verl.experimental.one_step_off_policy.utils import need_critic +from verl.experimental.separation.utils import create_resource_pool_manager, create_role_worker_mapping from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler -from verl.trainer.ppo.ray_trainer import ResourcePoolManager -from verl.trainer.ppo.utils import Role, need_reference_policy +from verl.trainer.ppo.utils import need_critic, need_reference_policy from verl.utils.config import validate_config from verl.utils.device import auto_set_device -def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager: - """ - Create resource pool manager - - Args: - config: Configuration object - roles: List of roles that need to create resource pools - - 
Returns: - ResourcePoolManager: Resource pool manager - """ - resource_pool_spec = {} - mapping = {} - - # Actor/Critic resource pool - if any(role in roles for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]): - assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0" - assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0" - - trainer_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes - resource_pool_spec["trainer_pool"] = trainer_pool - - # Map training-related roles to the same resource pool - for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]: - if role in roles: - mapping[role] = "trainer_pool" - - # Rollout resource pool - if Role.Rollout in roles: - assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" - assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" - - return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - -def create_role_worker_mapping(config): - """ - Create mapping from roles to worker classes - - Args: - config: Configuration object - - Returns: - dict: Mapping from roles to worker classes - """ - from verl.experimental.separation.engine_workers import DetachActorWorker - from verl.single_controller.ray import RayWorkerGroup - from verl.workers.engine_workers import TrainingWorker - - ray_worker_group_cls = RayWorkerGroup - - role_worker_mapping = { - Role.Actor: ray.remote(DetachActorWorker), - Role.Critic: ray.remote(TrainingWorker), - } - - # Add reference policy (if KL loss or reward is required) - if need_reference_policy(config): - role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker) - - return role_worker_mapping, ray_worker_group_cls - - @ray.remote(num_cpus=10, max_concurrency=100) # please make sure main_task is not scheduled on head class OneStepTaskRunner: def run(self, config): diff --git 
a/verl/experimental/one_step_off_policy/ray_trainer.py b/verl/experimental/one_step_off_policy/ray_trainer.py index 144632dead5..fca6ce01d12 100644 --- a/verl/experimental/one_step_off_policy/ray_trainer.py +++ b/verl/experimental/one_step_off_policy/ray_trainer.py @@ -31,7 +31,6 @@ from tqdm import tqdm from verl import DataProto -from verl.experimental.one_step_off_policy.utils import need_critic from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup from verl.trainer.ppo import core_algos @@ -40,7 +39,7 @@ compute_response_mask, ) from verl.trainer.ppo.reward import extract_reward -from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model +from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model from verl.utils.debug import marked_timer from verl.utils.rollout_skip import RolloutSkip from verl.utils.tracking import ValidationGenerationsLogger @@ -179,10 +178,10 @@ def _init_async_rollout_manager(self): # create async rollout manager and request scheduler assert self.config.actor_rollout_ref.rollout.mode == "async" - from verl.experimental.one_step_off_policy.agent_loop import OneStepOffAgentLoopManager + from verl.experimental.agent_loop import AgentLoopManager self.async_rollout_mode = True - self.async_rollout_manager = OneStepOffAgentLoopManager.create( + self.async_rollout_manager = AgentLoopManager.create( config=self.config, reward_loop_worker_handles=reward_loop_worker_handles ) @@ -224,7 +223,7 @@ async def _async_gen_next_batch(self, continuous_iterator): # async generation with marked_timer("generate_async", timing_raw, color="purple"): - gen_batch_output = await self.async_rollout_manager.generate_sequences_async(gen_batch_output) + gen_batch_output = await self.async_rollout_manager.generate_sequences(gen_batch_output) # repeat to align with repeated responses in rollout 
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) @@ -347,7 +346,7 @@ async def fit_step(self, batch_data_future, continuous_iterator): # Prevents computations in a certain phase from blocking the entire asynchronous workflow # # The purpose here is to ensure that after triggering - # `self.async_rollout_manager.generate_sequences_async(gen_batch_output)`, + # `self.async_rollout_manager.generate_sequences(gen_batch_output)`, # the subsequent relevant logic can proceed in a timely manner await asyncio.sleep(0) batch = self._fit_compute_reward(batch) diff --git a/verl/experimental/one_step_off_policy/utils.py b/verl/experimental/one_step_off_policy/utils.py deleted file mode 100644 index 1879b0672fa..00000000000 --- a/verl/experimental/one_step_off_policy/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2025 Meituan Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from omegaconf import DictConfig - -from verl.trainer.ppo.core_algos import AdvantageEstimator - - -def need_critic(config: DictConfig) -> bool: - """Given a config, do we need critic""" - if config.algorithm.adv_estimator == AdvantageEstimator.GAE: - return True - elif config.algorithm.adv_estimator in [ - AdvantageEstimator.GRPO, - AdvantageEstimator.GRPO_PASSK, - AdvantageEstimator.REINFORCE_PLUS_PLUS, - # AdvantageEstimator.REMAX, # TODO:REMAX advantage estimator is not yet supported in one_step_off_policy - AdvantageEstimator.RLOO, - AdvantageEstimator.OPO, - AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, - AdvantageEstimator.GPG, - ]: - return False - else: - raise NotImplementedError diff --git a/verl/experimental/reward_loop/reward_loop.py b/verl/experimental/reward_loop/reward_loop.py index 151089ec5c6..ceee8441fa2 100644 --- a/verl/experimental/reward_loop/reward_loop.py +++ b/verl/experimental/reward_loop/reward_loop.py @@ -294,6 +294,7 @@ def _init_reward_loop_workers(self): for i in range(num_workers): # Round-robin scheduling over the all nodes node_id = node_ids[i % len(node_ids)] + self.reward_loop_workers.append( self.reward_loop_workers_class.options( name=f"reward_loop_worker_{i}", diff --git a/verl/experimental/separation/engine_workers.py b/verl/experimental/separation/engine_workers.py index 0f8062ff888..97952b2816f 100644 --- a/verl/experimental/separation/engine_workers.py +++ b/verl/experimental/separation/engine_workers.py @@ -69,14 +69,14 @@ def _get_strategy_handlers(self): strategy = self.config.actor.strategy if strategy in ["fsdp", "fsdp2"]: - from verl.experimental.fully_async_policy.fsdp2_utils import ( + from verl.utils.fsdp_utils import ( fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu, ) self._strategy_handlers = (fsdp2_sharded_save_to_cpu, fsdp2_sharded_load_from_cpu) elif strategy == "megatron": - from verl.experimental.fully_async_policy.megatron_utils import ( + from verl.utils.megatron_utils import ( 
copy_megatron_model_to_cpu, restore_megatron_model_from_cpu, ) diff --git a/verl/experimental/separation/utils.py b/verl/experimental/separation/utils.py new file mode 100644 index 00000000000..3648cbef385 --- /dev/null +++ b/verl/experimental/separation/utils.py @@ -0,0 +1,92 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import ray + +from verl.trainer.ppo.ray_trainer import ResourcePoolManager +from verl.trainer.ppo.utils import Role, need_reference_policy + + +def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager: + """ + Create resource pool manager + + Args: + config: Configuration object + roles: List of roles that need to create resource pools + + Returns: + ResourcePoolManager: Resource pool manager + """ + resource_pool_spec = {} + mapping = {} + + # Actor/Critic resource pool + if any(role in roles for role in [Role.Actor, Role.ActorRollout, Role.Critic, Role.RefPolicy, Role.RewardModel]): + assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0" + assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0" + + trainer_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes + resource_pool_spec["trainer_pool"] = trainer_pool + + # Map training-related roles to the same resource pool + for role in [Role.Actor, Role.ActorRollout, Role.Critic, Role.RefPolicy, Role.RewardModel]: + if role in roles: + 
mapping[role] = "trainer_pool" + + # Rollout resource pool + if Role.Rollout in roles: + assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" + assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" + + return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + +def create_role_worker_mapping(config): + """ + Create mapping from roles to worker classes + + Args: + config: Configuration object + + Returns: + dict: Mapping from roles to worker classes + """ + # Select worker class based on strategy + if config.trainer.get("use_legacy_worker_impl", "auto") != "disable": + raise NotImplementedError( + "Fully async policy or One step off policy does not support legacy worker implementation" + ) + + from verl.experimental.separation.engine_workers import DetachActorWorker + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.engine_workers import TrainingWorker + + ray_worker_group_cls = RayWorkerGroup + + train_role = Role.Actor + if config.get("async_training", {}).get("use_trainer_do_validate", False): + train_role = Role.ActorRollout + + role_worker_mapping = { + train_role: ray.remote(DetachActorWorker), + Role.Critic: ray.remote(TrainingWorker), + } + + # Add reference policy (if KL loss or reward is required) + if need_reference_policy(config): + role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker) + + return role_worker_mapping, ray_worker_group_cls diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 117f2df8d41..2db7549e01d 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -427,8 +427,12 @@ def split(self, num_splits: int): print(f"total_samples: {total_samples}") if total_samples == 0: raise ValueError("Cannot split an empty dataset") + + # Calculate effective sample count after dropping remainders if needed if total_samples % num_splits != 0: - raise 
ValueError(f"Cannot split dataset size {total_samples} into {num_splits} splits") +            total_samples = total_samples - (num_dropped := total_samples % num_splits) +            logging.warning(f"Dropping {num_dropped} samples, effective samples: {total_samples}") + split_size = total_samples // num_splits splits = [] diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 8f7f8bef0d0..52ffcb430c0 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -20,7 +20,7 @@ from abc import ABC from collections import OrderedDict from contextlib import contextmanager, nullcontext -from typing import cast +from typing import Optional, cast import torch import torch.distributed as dist @@ -38,7 +38,8 @@ if version.parse(torch.__version__) >= version.parse("2.6"): from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard from torch.distributed.fsdp._fully_shard._fsdp_init import _get_post_forward_mesh_info - from torch.distributed.tensor import Shard + from torch.distributed.tensor import DTensor, Shard + from torch.distributed.tensor._dtensor_spec import DTensorSpec fully_shard_module = torch.distributed.fsdp._fully_shard._fully_shard elif version.parse(torch.__version__) >= version.parse("2.4"): @@ -891,3 +892,103 @@ def merged_lora_context(actor, backup_adapters=False): else: # Fall back to unmerge if no backup was made fsdp_merge_unmerge(actor, do_merge=False) + + +def fsdp2_sharded_save_to_cpu( + model: torch.nn.Module, +) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]: + """ + Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory. + + Args: + model: FSDP2-wrapped model whose parameters are of DTensor type. + + Returns: + cpu_sharded_state: Dictionary of CPU shards for the current process.
+ Key = parameter name, Value = (CPU shard tensor, original DTensorSpec) + global_spec: DTensorSpec of the first parameter (used to verify global rules during loading) + """ + cpu_sharded_state = {} + global_spec = None # Record global sharding rules (all parameters follow the same spec) + + for param_name, param in model.named_parameters(): + # Only process sharded parameters of DTensor type (core parameters of FSDP2) + if not isinstance(param, DTensor): + # Save non-sharded parameters (e.g., running_mean of BatchNorm) as local data + cpu_tensor = param.detach().cpu() + cpu_sharded_state[param_name] = (cpu_tensor, None) + continue + + # Record global sharding rules (take spec of the first DTensor to ensure consistency) + if global_spec is None: + global_spec = param._spec + assert hasattr(global_spec, "device_mesh"), "DTensorSpec must contain 'device_mesh' attribute" + assert hasattr(global_spec, "placements"), "DTensorSpec must contain 'placements' attribute" + + # 1. Extract local shard data from the current GPU (_local_tensor) + local_gpu_tensor = param._local_tensor # DTensor's local shard held on this rank's device + # 2. Move to CPU memory and detach from computation graph + local_cpu_tensor = local_gpu_tensor.detach().cpu() + # 3. Save CPU shard + original DTensorSpec (ensure sharding rules remain unchanged) + cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec) + + assert global_spec is not None, "No DTensor-type parameters found in the model. FSDP2 sharding may not be enabled." + return cpu_sharded_state, global_spec + + +def fsdp2_sharded_load_from_cpu( + model: torch.nn.Module, + cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]], + target_spec: DTensorSpec, +) -> None: + """ + Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU, + keeping sharding rules unchanged.
+ + Args: + model: FSDP2 model to be restored (must have the same structure as when saved) + cpu_sharded_state: Shard data read from CPU memory by the current process + (from fsdp2_sharded_save_to_cpu) + target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency) + """ + # Verify device_mesh consistency (core: ensure loaded shards map to original GPUs) + current_device_mesh = None + for param in model.parameters(): + if isinstance(param, DTensor): + current_device_mesh = param._spec.device_mesh + break + assert current_device_mesh is not None, "DTensor parameters not initialized in the model to be loaded" + assert current_device_mesh == target_spec.device_mesh, ( + f"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}" + ) + + for param_name, param in model.named_parameters(): + # Skip parameters not in the saved state (e.g., newly added parameters) + if param_name not in cpu_sharded_state: + continue + + # Extract CPU shard data and original Spec + local_cpu_tensor, saved_spec = cpu_sharded_state[param_name] + + # Handle different parameter types: DTensor sharded parameters vs. regular parameters + if isinstance(param, DTensor): + # 1. Verify sharding rule consistency (placements must match original Spec) + assert saved_spec is not None, f"DTensorSpec missing in saved state for parameter {param_name}" + assert saved_spec.placements == target_spec.placements, ( + f"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!" + ) + + # 2. Move CPU shard data to the current GPU (device of param._local_tensor) + target_device = param._local_tensor.device + local_gpu_tensor = local_cpu_tensor.to(target_device) + + # 3. 
Restore to DTensor's local shard (directly copy to _local_tensor, keep spec unchanged) + param._local_tensor.copy_(local_gpu_tensor) + + else: + # Regular parameters: load directly to original device + target_device = param.device + param.data.copy_(local_cpu_tensor.to(target_device)) + + # Process synchronization: ensure all processes complete loading before proceeding + dist.barrier() diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 708c6e24fa5..aa93ea55087 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -1368,3 +1368,85 @@ def patch_engine_mtp(module, model_config): patch_postprocess(module) if model_config.mtp.detach_encoder: patch_mtp_layer_get_embeddings(module) + + +@torch.no_grad() +def copy_megatron_model_to_cpu(models): + """ + Copy Megatron model parameters to CPU memory (non-destructive copy). + Unlike offload_megatron_model_to_cpu which moves data, this function creates + independent copies on CPU while keeping GPU data intact. 
+ + Args: + models: List of model chunks (DDP-wrapped or unwrapped) + + Returns: + dict: CPU state containing copied parameters and buffers + """ + cpu_state = {} + + for model_idx, model_chunk in enumerate(models): + if isinstance(model_chunk, DDP): + # Handle DDP-wrapped models + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + buffer_states = [] + + for buffers in model_chunk_all_buffers: + buffer_list = [] + for buffer in buffers: + buffer_state = {} + + # Copy parameter data to CPU + if buffer.param_data.storage().size() > 0: + buffer_state["param_data"] = buffer.param_data.data.cpu().clone().pin_memory() + + buffer_list.append(buffer_state) + buffer_states.append(buffer_list) + + cpu_state[f"model_chunk_{model_idx}"] = {"buffer_states": buffer_states, "is_ddp": True} + else: + # Handle non-DDP models (ref module) + model_state = {} + for name, param in model_chunk.named_parameters(): + param_state = {"data": param.data.cpu().clone().pin_memory()} + model_state[name] = param_state + + cpu_state[f"model_chunk_{model_idx}"] = {"model_state": model_state, "is_ddp": False} + + return cpu_state + + +@torch.no_grad() +def restore_megatron_model_from_cpu(models, cpu_state): + """ + Restore Megatron model parameters from CPU memory back to GPU. 
+ + Args: + models: List of model chunks to restore to + cpu_state: CPU state dict returned from copy_megatron_model_to_cpu + """ + for model_idx, model_chunk in enumerate(models): + chunk_key = f"model_chunk_{model_idx}" + if chunk_key not in cpu_state: + continue + + chunk_state = cpu_state[chunk_key] + + if chunk_state["is_ddp"] and isinstance(model_chunk, DDP): + # Restore DDP buffers + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + buffer_states = chunk_state["buffer_states"] + + for buffers, buffer_list in zip(model_chunk_all_buffers, buffer_states, strict=False): + for buffer, buffer_state in zip(buffers, buffer_list, strict=False): + # Restore parameter data + if "param_data" in buffer_state: + buffer.param_data.data.copy_(buffer_state["param_data"].to(buffer.param_data.device)) + + elif not chunk_state["is_ddp"] and not isinstance(model_chunk, DDP): + # Restore non-DDP models + model_state = chunk_state["model_state"] + for name, param in model_chunk.named_parameters(): + if name in model_state: + param_state = model_state[name] + param.data.copy_(param_state["data"].to(param.device)) diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py index f6571a4ec5e..969c6208083 100644 --- a/verl/workers/rollout/replica.py +++ b/verl/workers/rollout/replica.py @@ -47,8 +47,8 @@ class TokenOutput(BaseModel): """stop reason: 'completed', 'aborted', or None for unknown""" num_preempted: Optional[int] = None """number of preempted times for metric calculation""" - extra_info: dict[str, Any] = {} - """extra info for rollout""" + extra_fields: dict[str, Any] = {} + """Extra fields for dynamic addition.""" class RolloutMode(Enum): diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index ef01006f51f..c03ada27aba 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ 
b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -414,7 +414,7 @@ async def generate( log_probs=log_probs, routed_experts=routed_experts, stop_reason=finish_reason, - extra_info={"global_steps": self.global_steps}, + extra_fields={"global_steps": self.global_steps}, ) async def set_global_steps(self, global_steps: int): @@ -510,6 +510,7 @@ async def launch_servers(self): if not self.is_reward_model else f"sglang_server_reward_{self.replica_rank}_{node_rank}" ) + server = self.server_class.options( scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( node_id=node_id, diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 5a2dbf6fb76..24edab081a2 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -259,7 +259,7 @@ async def generate( if outputs.outputs[0].logprobs is not None: # When logprobs=1, TRT-LLM returns only the sampled token's logprob at each position log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs] - return TokenOutput(token_ids=token_ids, log_probs=log_probs, extra_info={"global_steps": self.global_steps}) + return TokenOutput(token_ids=token_ids, log_probs=log_probs, extra_fields={"global_steps": self.global_steps}) async def set_global_steps(self, global_steps: int): """Set the global steps of the model weights.""" diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 1f7c5ea9d2d..f7ff10190bc 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -609,7 +609,7 @@ async def generate( routed_experts=routed_experts, stop_reason=stop_reason, num_preempted=num_preempted, - extra_info={"global_steps": self.global_steps}, + extra_fields={"global_steps": 
self.global_steps}, ) async def wake_up(self): @@ -841,6 +841,7 @@ async def launch_servers(self): if not self.is_reward_model else f"vllm_server_reward_{self.replica_rank}_{node_rank}" ) + server = self.server_class.options( scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( node_id=node_id,