diff --git a/.github/workflows/checkpoint_converter.yml b/.github/workflows/checkpoint_converter.yml index cea6dbf16e3..52bd004602b 100644 --- a/.github/workflows/checkpoint_converter.yml +++ b/.github/workflows/checkpoint_converter.yml @@ -51,7 +51,7 @@ jobs: NO_PROXY: "localhost,127.0.0.1" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -81,7 +81,7 @@ jobs: HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable HF_ENDPOINT: "https://hf-mirror.com" container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/dataset.yml b/.github/workflows/dataset.yml index 5e9fa413677..e04f1eb513d 100644 --- a/.github/workflows/dataset.yml +++ b/.github/workflows/dataset.yml @@ -34,7 +34,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/disabled/e2e_prime.yml b/.github/workflows/disabled/e2e_prime.yml index 61c7e86cfb9..0cb85d817c6 100644 --- a/.github/workflows/disabled/e2e_prime.yml +++ b/.github/workflows/disabled/e2e_prime.yml @@ -47,7 +47,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/e2e_dapo.yml b/.github/workflows/e2e_dapo.yml index 784e2a071c6..cdf1f426008 100644 --- a/.github/workflows/e2e_dapo.yml +++ b/.github/workflows/e2e_dapo.yml @@ -49,7 +49,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/e2e_eval_aime24.yml b/.github/workflows/e2e_eval_aime24.yml index 63532d7ce87..225044a8cb9 100644 --- a/.github/workflows/e2e_eval_aime24.yml +++ b/.github/workflows/e2e_eval_aime24.yml @@ -48,7 +48,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/e2e_ppo_trainer.yml 
b/.github/workflows/e2e_ppo_trainer.yml index 057d3d0ca30..78a5d8cbce9 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -69,7 +69,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index 8fc466ec76a..b68036bff4f 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ -240,7 +240,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/e2e_sft.yml b/.github/workflows/e2e_sft.yml index b657dae4f71..7542deebb30 100644 --- a/.github/workflows/e2e_sft.yml +++ b/.github/workflows/e2e_sft.yml @@ -49,7 +49,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/kernels.yml b/.github/workflows/kernels.yml index 0a6f9163dde..cda8edf3ee2 100644 --- a/.github/workflows/kernels.yml +++ b/.github/workflows/kernels.yml @@ -47,7 +47,7 @@ jobs: NO_PROXY: "localhost,127.0.0.1" HF_HUB_ENABLE_HF_TRANSFER: 1 container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml index c554f1071e6..7fc18c3c3e5 100644 --- a/.github/workflows/model.yml +++ b/.github/workflows/model.yml @@ -34,7 +34,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -93,7 +93,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/ray_gpu_test.yml b/.github/workflows/ray_gpu_test.yml index 5143965eadb..8d8337fdd71 100644 --- a/.github/workflows/ray_gpu_test.yml +++ 
b/.github/workflows/ray_gpu_test.yml @@ -36,7 +36,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/sandbox.yml b/.github/workflows/sandbox.yml index 23d7b3ed8de..cd1f497e0db 100644 --- a/.github/workflows/sandbox.yml +++ b/.github/workflows/sandbox.yml @@ -35,7 +35,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/utils_gpu_test.yml b/.github/workflows/utils_gpu_test.yml index fa68797296b..279364e512b 100644 --- a/.github/workflows/utils_gpu_test.yml +++ b/.github/workflows/utils_gpu_test.yml @@ -30,7 +30,7 @@ jobs: runs-on: [L20x8] timeout-minutes: 20 # Increase this timeout value as needed container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh index 2976534698a..e355e2f9c7d 100644 --- a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh @@ -4,8 +4,10 @@ set -x # export VLLM_ATTENTION_BACKEND=XFORMERS # For async rollout mode, dataset should return raw chat. 
-rollout_mode="sync" +rollout_mode="async" +rollout_name="sglang" # sglang or vllm if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 return_raw_chat="True" chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler fi @@ -34,7 +36,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.name=$rollout_name \ actor_rollout_ref.rollout.mode=$rollout_mode \ actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index dde760a9f9b..efb6620df5d 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -13,6 +13,7 @@ MAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512} MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512} ENGINE=${ENGINE:-vllm} +ROLLOUT_MODE=${ROLLOUT_MODE:-sync} GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8} ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False} ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False} @@ -101,6 +102,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name="${ENGINE}" \ + actor_rollout_ref.rollout.mode="${ROLLOUT_MODE}" \ actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \ actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \ actor_rollout_ref.rollout.gpu_memory_utilization="${GPU_MEMORY_UTILIZATION}" \ diff --git a/tests/workers/rollout/test_async_sglang_server.py b/tests/workers/rollout/test_async_sglang_server.py index 914f527c9e9..ae5b7b8c0c4 100644 --- a/tests/workers/rollout/test_async_sglang_server.py +++ b/tests/workers/rollout/test_async_sglang_server.py @@ -25,14 +25,6 @@ }, ) class TestAsyncSglangServer: - @pytest.fixture - def mock_ray_actor(self): - mock_actor = MagicMock() - mock_actor.execute_method.remote = AsyncMock(return_value={"content": "mocked response"}) - mock_actor.resume.remote = AsyncMock() - mock_actor.offload.remote = AsyncMock() - return mock_actor - @pytest.fixture def server_config(self): return DictConfig({"rollout": {"tensor_model_parallel_size": 2}}) @@ -41,11 +33,16 @@ def server_config(self): @patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors") @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock) @pytest.mark.filterwarnings("ignore:Ray state API is no longer experimental:DeprecationWarning") - async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config, mock_ray_actor): + async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config): mock_list_actors.return_value = [ + {"name": "test_prefixWorkerDict_1:0", "namespace": "test"}, + {"name": "test_prefixWorkerDict_1:1", "namespace": "test"}, {"name": "test_prefixWorkerDict_0:0", "namespace": "test"}, {"name": "test_prefixWorkerDict_0:1", "namespace": "test"}, {"name": "test_prefixWorkerDict_1:2", "namespace": "test"}, + {"name": "test_prefixWorkerDict_1:3", "namespace": "test"}, + {"name": "test_prefixWorkerDict_0:2", "namespace": "test"}, + {"name": 
"test_prefixWorkerDict_0:3", "namespace": "test"}, ] from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer @@ -53,10 +50,50 @@ async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, se if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr(AsyncSglangServer.__ray_metadata__, "modified_class"): ActualClassToInstantiate = AsyncSglangServer.__ray_metadata__.modified_class - with patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", return_value=mock_ray_actor): - instance = ActualClassToInstantiate(server_config, 2, 0, "test_prefix") + def mock_get_actor_side_effect(name, namespace=None): + # Create a new mock actor for each call + actor_mock = MagicMock() + + # Support .name attribute access + actor_mock.name = name # Use 'name' here + + # Support ['name'] item access by mocking __getitem__ + def getitem_mock(key): + if key == "name": + return name # Use 'name' here + # For other keys, return a new MagicMock to mimic default behavior or raise KeyError + # Returning a MagicMock is consistent with the original error's cause for unmocked keys + return MagicMock(name=f"mock.__getitem__('{key}')") + + actor_mock.__getitem__.side_effect = getitem_mock + + return actor_mock + + # Verify instance.workers is correctly populated + with patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", side_effect=mock_get_actor_side_effect): + # Instance 1 + instance = ActualClassToInstantiate(server_config, 4, 0, "test_prefix") + await instance.init_engine() + + assert len(instance.workers) == 2 + assert instance.master_worker["name"] == "test_prefixWorkerDict_0:0" + assert instance.workers[0].name == "test_prefixWorkerDict_0:0" + assert instance.workers[1].name == "test_prefixWorkerDict_0:1" + + # Instance 2 + instance = ActualClassToInstantiate(server_config, 4, 1, "test_prefix") + await instance.init_engine() + + assert len(instance.workers) == 2 + assert instance.master_worker["name"] == "test_prefixWorkerDict_0:2" + assert instance.workers[0].name == "test_prefixWorkerDict_0:2" + assert instance.workers[1].name == "test_prefixWorkerDict_0:3" + # Instance 3 + instance = ActualClassToInstantiate(server_config, 4, 3, "test_prefix") await instance.init_engine() - # Verify instance.workers is correctly populated assert len(instance.workers) == 2 + assert instance.master_worker["name"] == "test_prefixWorkerDict_1:2" + assert instance.workers[0].name == "test_prefixWorkerDict_1:2" + assert instance.workers[1].name == "test_prefixWorkerDict_1:3" diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index bda714f3c4c..5675d33d42d 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -82,9 +82,9 @@ def run(self, config): elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker - actor_rollout_cls = ActorRolloutRefWorker + actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker ray_worker_group_cls = NVMegatronRayWorkerGroup else: diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 96b75e7a205..3a3c0c82908 100644 --- a/verl/workers/fsdp_workers.py +++ 
b/verl/workers/fsdp_workers.py @@ -1430,10 +1430,19 @@ def execute_method(self, method: Union[str, bytes], *args, **kwargs): print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}") return self.rollout.execute_method(method, *args, **kwargs) + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def chat_completion(self, json_request): + ret = await self.rollout.chat_completion(json_request) + return ret + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def resume(self): - return self.rollout.resume() + async def wake_up(self): + await self.rollout.wake_up() + # return something to block the caller + return True @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def offload(self): - return self.rollout.offload() + async def sleep(self): + await self.rollout.sleep() + # return something to block the caller + return True diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 11565b5ffeb..01bf207c4bb 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -19,6 +19,7 @@ import os import time import warnings +from typing import Union import torch import torch.distributed @@ -308,7 +309,7 @@ def _build_rollout(self, trust_remote_code=False): log_gpu_memory_usage("After building sharding manager", logger=logger) else: raise NotImplementedError("Only vllmRollout is supported with Megatron now") - + print(f"rollout and sharding manager init done sharding_manager: {sharding_manager}") return rollout, sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -362,6 +363,8 @@ def init_model(self): if self._is_rollout: self.rollout, self.sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + # used for sleep/wake_up + self.rollout.sharding_manager = self.sharding_manager log_gpu_memory_usage("After rollout init", logger=logger) if self._is_ref: @@ -548,6 +551,47 @@ def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ck offload_megatron_model_to_cpu(self.actor_module) +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + def _build_rollout(self, trust_remote_code=False): + rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) + + # NOTE: rollout is not actually initialized here, it's deferred + # to be initialized by AsyncvLLMServer. 
+ + self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size + self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size + self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size + + # used for sleep/wake_up + rollout.sharding_manager = rollout_sharding_manager + + return rollout, rollout_sharding_manager + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + """Called by ExternalRayDistributedExecutor collective_rpc.""" + if self.vllm_tp_rank == 0 and method != "execute_model": + print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}") + return self.rollout.execute_method(method, *args, **kwargs) + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def chat_completion(self, json_request): + ret = await self.rollout.chat_completion(json_request) + return ret + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + async def wake_up(self): + await self.rollout.wake_up() + # return something to block the caller + return True + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + async def sleep(self): + await self.rollout.sleep() + # return something to block the caller + return True + + class CriticWorker(MegatronWorker): def __init__(self, config): super().__init__() diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index b3a36818871..051dc1e754d 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -36,37 +36,42 @@ def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: st self._dp_rank = dp_rank self.wg_prefix = wg_prefix self.workers = [] + self.master_worker = None async def init_engine(self): + if self.workers: + # avoid init twice + return all_actors = ray.util.list_named_actors(all_namespaces=True) matched_actors = [actor for actor in all_actors if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_")] - # TODO support multi node for matched_actor in matched_actors: - current_rank = int(matched_actor["name"].split(":")[-1]) + fields = matched_actor["name"].split(":") + assert len(fields) == 2, f"invalid actor name: {matched_actor['name']}" + pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) - # send to all works in this tp group, because sglang is SPMD - if current_rank >= self._dp_rank * self._tp_size and current_rank < (self._dp_rank + 1) * self._tp_size: - self.workers.append(ray.get_actor(**matched_actor)) + if (self._dp_size * pg_index + local_rank) // self._tp_size == self._dp_rank: + worker = ray.get_actor(**matched_actor) + self.workers.append(worker) + if (self._dp_size * pg_index + local_rank) / self._tp_size == self._dp_rank: + self.master_worker = worker async def chat_completion(self, raw_request: Request): request = await raw_request.json() - output_dp_lst = [] - for worker in self.workers: - output_future = worker.execute_method.remote("chat_completion", request) - output_dp_lst.append(output_future) - outputs = await asyncio.gather(*output_dp_lst) - - for output in outputs: - if output is not None: - return JSONResponse(output) - raise RuntimeError("AsyncSglangServer No output from workers self._dp_rank: {self._dp_rank}, self._tp_size: {self._tp_size}, self.workers: {self.workers}") + # only send request to master worker in tp rank 0 + 
output_future = self.master_worker.chat_completion.remote(request) + [outputs] = await asyncio.gather(output_future) + return JSONResponse(outputs) - async def wake_up(self): + def wake_up(self): + futures = [] for worker in self.workers: - worker.resume.remote() + futures.append(worker.wake_up.remote()) + ray.get(futures) - async def sleep(self): + def sleep(self): + futures = [] for worker in self.workers: - worker.offload.remote() + futures.append(worker.sleep.remote()) + ray.get(futures) diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index cf4e20580ce..bd4d635d259 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -17,22 +17,38 @@ import asyncio import logging +import multiprocessing as mp import os import time from contextlib import contextmanager from copy import deepcopy from json import JSONDecodeError -from typing import Union +from typing import List, Optional, Tuple from uuid import uuid4 import numpy as np +import sglang.srt.entrypoints.engine import torch import torch.distributed as dist from omegaconf import DictConfig -from sglang.srt.entrypoints.engine import Engine +from sglang.srt.managers.tokenizer_manager import ( + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightsFromTensorReqInput, +) from sglang.srt.openai_api.protocol import Tool from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.utils import get_ip, get_open_port +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + MultiprocessingSerializer, + assert_pkg_version, + get_ip, + get_open_port, + is_cuda, + maybe_set_triton_cache_manager, + set_prometheus_multiproc_dir, + set_ulimit, +) from tensordict import TensorDict from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.nn.utils.rnn import pad_sequence @@ -72,6 +88,93 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723 +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + os.environ["CUDA_MODULE_LOADING"] = "AUTO" + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. 
+        maybe_set_triton_cache_manager()
+
+    # Check flashinfer version
+    if server_args.attention_backend == "flashinfer":
+        assert_pkg_version(
+            "flashinfer_python",
+            "0.2.5",
+            "Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.",
+        )
+    if is_cuda():
+        assert_pkg_version(
+            "sgl-kernel",
+            "0.1.1",
+            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
+        )
+
+    # Set mp start method
+    mp.set_start_method("spawn", force=True)
+
+
+sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config
+
+
+# Because chat_completion is an async method, the whole Ray actor becomes an async actor,
+# which cannot call loop.run_until_complete. So the engine itself has to expose async methods.
+class AsyncEngine(sglang.srt.entrypoints.engine.Engine):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # default to the dummy load format, which requires reloading weights on first use
+        self._need_reload = True
+
+    async def release_memory_occupation(self):
+        """Release GPU occupation temporarily."""
+        obj = ReleaseMemoryOccupationReqInput()
+        return await self.tokenizer_manager.release_memory_occupation(obj, None)
+
+    async def resume_memory_occupation(self):
+        """Resume GPU occupation."""
+
+        # Because __init__ is a sync method, it cannot call the async release_memory_occupation,
+        # so the initial release is deferred from __init__ to here.
+        if self._need_reload:
+            await self.release_memory_occupation()
+            self._need_reload = False
+
+        obj = ResumeMemoryOccupationReqInput()
+        return await self.tokenizer_manager.resume_memory_occupation(obj, None)
+
+    async def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],  # noqa: UP006
+        load_format: Optional[str] = None,
+        flush_cache: bool = True,
+    ):
+        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
+        to avoid duplicated cache cleaning operation."""
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=[MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size)],
+            load_format=load_format,
+            flush_cache=flush_cache,
+        )
+        return await self.tokenizer_manager.update_weights_from_tensor(obj, None)
+
+    async def flush_cache(self):
+        return await self.tokenizer_manager.flush_cache()
+
+
 # NOTE(sgm): add for verl. We can optimize it by making
 # the dataloader yield List[int] without padding.
def _pre_process_inputs( @@ -253,7 +356,7 @@ def _init_inference_engine(self, trust_remote_code, actor_module, port): if first_rank_in_node: rank = dist.get_rank() os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = Engine( + self._engine = AsyncEngine( model_path=actor_module, dtype=self.config.dtype, mem_fraction_static=self.config.gpu_memory_utilization, @@ -281,9 +384,6 @@ def _init_inference_engine(self, trust_remote_code, actor_module, port): self._engine = None self.sharding_manager = None - # offload - if self._tp_rank == 0: - self._engine.release_memory_occupation() self.is_sleep = True def _init_sampling_params(self, **kwargs): @@ -614,7 +714,8 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP # free cache engine if self.config.free_cache_engine and self._engine is not None: - self._engine.flush_cache() + loop = asyncio.get_event_loop() + loop.run_until_complete(self._engine.flush_cache()) return DataProto(batch=batch, non_tensor_batch=_non_tensor_batch) @@ -749,7 +850,7 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool): return _req - async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs) -> dict: + async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, override_n: bool = True, **kwargs) -> dict: generation_prompt_ids = _req.get_generation_prompt(self.tokenizer) max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) if not do_sample: @@ -776,7 +877,7 @@ async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, "n": 1, # if validate, already repeat in ray_trainer } kwargs["max_new_tokens"] = max_new_tokens - if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess + if "n" not in kwargs or (kwargs["n"] > 1 and override_n): # group size is supported in preprocess kwargs["n"] = 1 # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): @@ -931,7 +1032,8 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro # free cache engine if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0: - self._engine.flush_cache() + loop = asyncio.get_event_loop() + loop.run_until_complete(self._engine.flush_cache()) return DataProto( batch=batch, @@ -1018,85 +1120,80 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in return req_list - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - if method == "chat_completion": - json_request = args[0] - - formatted_messages = [] - for msg in json_request["messages"]: - role = msg.get("role", "user") - content = msg.get("content", "") - formatted_messages.append(f"{role}: {content}") - prompt_str = "\n".join(formatted_messages) - - sampling_params_dict = { - "n": json_request.get("n", 1), - "max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length), - "temperature": json_request.get("temperature", 1.0), - "top_p": json_request.get("top_p", 1.0), - } - output = None - if self._tp_rank == 0: - loop = asyncio.get_event_loop() - output = loop.run_until_complete( - self._engine.async_generate( - prompt=prompt_str, - sampling_params=sampling_params_dict, - return_logprob=True, - ) - ) + async def chat_completion(self, json_request): + assert self._tp_rank == 0, "only called in tp rank 0" + _input_ids = [] + 
_attention_mask = [] + _position_ids = [] + _tool_schemas = [] + _tools_kwargs = {} + + req = AsyncRolloutRequest( + request_id=str(uuid4()), + state=AsyncRolloutRequestStateEnum.PENDING, + messages=[Message.model_validate(msg) for msg in json_request["messages"]], + tools=_tool_schemas, + tools_kwargs=_tools_kwargs, + input_ids=_input_ids, + prompt_ids=_input_ids, + response_ids=[], + attention_mask=_attention_mask, + prompt_attention_mask=_attention_mask, + response_attention_mask=[], + position_ids=_position_ids, + prompt_position_ids=_position_ids, + response_position_ids=[], + loss_mask=[0] * len(_input_ids), + prompt_loss_mask=[0] * len(_input_ids), + response_loss_mask=[], + reward_scores={}, + max_response_len=self.config.response_length, + max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + ) - dist.barrier() - output = broadcast_pyobj( - data=[output], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, + # json_request already contains sampling_params + output = await self._handle_engine_call(req, True, False, False, **json_request) + # it can be Dict or AsyncIterator[Dict] + if isinstance(output, dict): + outputs = [output] + else: + outputs = output + + # build openai chat completion format + choices = [] + id = None + for i, content in enumerate(outputs): + choices.append( + { + "index": i, + "message": { + "role": "assistant", + "content": content["text"], + }, + "finish_reason": content["meta_info"]["finish_reason"]["type"], + } ) + id = content["meta_info"]["id"] - # only return value from master rank - if self._tp_rank != 0: - return None - # build openai chat completion format - choices = [] - id = None - for i, content in enumerate(output): - choices.append( - { - "index": i, - "message": { - "role": "assistant", - "content": content["text"], - }, - "finish_reason": content["meta_info"]["finish_reason"]["type"], - } - ) - id = content["meta_info"]["id"] - - return { - "id": "chatcmpl-" + id, - "object": "chat.completion", - "created": int(time.time()), - "model": json_request.get("model", "sglang_model"), - "choices": choices, - } - else: - raise ValueError(f"not supported method : {method}") + return { + "id": "chatcmpl-" + id, + "object": "chat.completion", + "created": int(time.time()), + "model": json_request.get("model", "sglang_model"), + "choices": choices, + } # this function is left for uniform train-inference resharding - def resume(self): + async def wake_up(self): if not self.is_sleep: return - self.sharding_manager.__enter__() # pylint: disable=C2801 - + await self.sharding_manager.wake_up() # pylint: disable=C2801 self.is_sleep = False # this function is left for uniform train-inference resharding - def offload(self): + async def sleep(self): if self.is_sleep: return - - self.sharding_manager.__exit__(None, None, None) + await self.sharding_manager.sleep() self.is_sleep = True diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py index 3608d932b6c..90ffe46e4f1 100644 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ b/verl/workers/sharding_manager/fsdp_sglang.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import logging import os @@ -99,7 +100,8 @@ def __enter__(self): device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()} # Copy, not share memory - self.update_weights(params) + loop = asyncio.get_event_loop() + loop.run_until_complete(self.update_weights(params)) log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params @@ -116,7 +118,8 @@ def __enter__(self): @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) def __exit__(self, exc_type, exc_value, traceback): log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - self.release_memory() + loop = asyncio.get_event_loop() + loop.run_until_complete(self.release_memory()) log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) self.module.train() @@ -129,9 +132,9 @@ def __exit__(self, exc_type, exc_value, traceback): self.gen_random_states = torch.cuda.get_rng_state() torch.cuda.set_rng_state(self.torch_random_states) - def update_weights(self, params): + async def update_weights(self, params): if self.device_mesh["infer_tp"].get_local_rank() == 0: - self.inference_engine.resume_memory_occupation() + await self.inference_engine.resume_memory_occupation() # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update named_tensors = [(k, v) for k, v in params.items()] @@ -151,7 +154,7 @@ def update_weights(self, params): ) if self.device_mesh["infer_tp"].get_local_rank() == 0: - self.inference_engine.update_weights_from_tensor( + await self.inference_engine.update_weights_from_tensor( named_tensors=[ ( name, @@ -162,9 +165,50 @@ def update_weights(self, params): flush_cache=tensor_index == len(named_tensors) - 1, ) - def release_memory(self): + async def release_memory(self): if self.device_mesh["infer_tp"].get_local_rank() == 0: - self.inference_engine.release_memory_occupation() + await self.inference_engine.release_memory_occupation() + + @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger) + async def wake_up(self): + torch.cuda.empty_cache() + log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_fsdp_model_to_gpu(self.module) + params = self.module.state_dict() + log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) + device = torch.cuda.current_device() # used when fsdp2 set cpu_offload_policy + params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()} + # Copy, not share memory + await self.update_weights(params) + log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) + + del params + if self.offload_param: + offload_fsdp_model_to_cpu(self.module) + torch.cuda.empty_cache() + log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) + + # important: need to manually set the random states of each tp to be identical. 
+ if self.device_mesh is not None: + self.torch_random_states = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.gen_random_states) + + @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) + async def sleep(self): + log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + await self.release_memory() + log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + + self.module.train() + + # add empty cache after each compute + torch.cuda.empty_cache() + + # restore random states + if self.device_mesh is not None: + self.gen_random_states = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) def preprocess_data(self, data: DataProto) -> DataProto: """All gather across tp group to make each rank has identical input.""" diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py index 4e047e212ff..0018b10a2ad 100644 --- a/verl/workers/sharding_manager/megatron_sglang.py +++ b/verl/workers/sharding_manager/megatron_sglang.py @@ -17,6 +17,7 @@ This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine. """ +import asyncio import logging import os @@ -89,8 +90,8 @@ def __enter__(self): self.transformer_config, self.layer_name_mapping, ) - self.update_weights(per_tensor_param) - + loop = asyncio.get_event_loop() + loop.run_until_complete(self.update_weights(per_tensor_param)) # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: self.torch_random_states = torch.cuda.get_rng_state() @@ -99,7 +100,8 @@ def __enter__(self): @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) def __exit__(self, exc_type, exc_value, traceback): log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - self.release_memory() + loop = asyncio.get_event_loop() + loop.run_until_complete(self.release_memory()) log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) for model in self.actor_module: @@ -112,9 +114,9 @@ def __exit__(self, exc_type, exc_value, traceback): self.gen_random_states = torch.cuda.get_rng_state() torch.cuda.set_rng_state(self.torch_random_states) - def update_weights(self, params): + async def update_weights(self, params): if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.resume_memory_occupation() + await self.inference_engine.resume_memory_occupation() # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update # named_tensors = [(k, v) for k, v in params.items()] @@ -122,7 +124,7 @@ def update_weights(self, params): load_format = None for tensor_index, (name, tensor) in enumerate(named_tensors): if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.update_weights_from_tensor( + await self.inference_engine.update_weights_from_tensor( named_tensors=[ ( name, @@ -134,11 +136,42 @@ def update_weights(self, params): ) if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.flush_cache() + await self.inference_engine.flush_cache() - def release_memory(self): + async def release_memory(self): if self.device_mesh["tp"].get_local_rank() == 0: - self.inference_engine.release_memory_occupation() + await self.inference_engine.release_memory_occupation() + + @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger) + async def wake_up(self): + per_tensor_param = 
per_tensor_generator( + self.actor_module, + self.model_config, + self.weight_converter, + self.transformer_config, + self.layer_name_mapping, + ) + await self.update_weights(per_tensor_param) + # important: need to manually set the random states of each tp to be identical. + if self.device_mesh is not None: + self.torch_random_states = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.gen_random_states) + + @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) + async def sleep(self): + log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + await self.release_memory() + log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + + for model in self.actor_module: + model.train() + # add empty cache after each compute + torch.cuda.empty_cache() + + # restore random states + if self.device_mesh is not None: + self.gen_random_states = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.torch_random_states) @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) def preprocess_data(self, data: DataProto) -> DataProto:
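
For the CI entry points, tests/e2e/ppo_trainer/run_function_reward.sh now exposes the rollout mode through an environment variable with a backward-compatible default (ROLLOUT_MODE=${ROLLOUT_MODE:-sync}), mirroring the existing ENGINE=${ENGINE:-vllm} knob. Existing invocations keep running in sync mode; an async SGLang configuration can be selected by exporting, for example, ENGINE=sglang ROLLOUT_MODE=async before calling the script.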
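
In async_sglang_server.py, the rewritten init_engine derives each worker's placement-group index and local rank from its actor name and keeps only the workers whose global rank falls into this server's data-parallel group; the worker at TP rank 0 of that group is recorded as master_worker, and chat_completion is now forwarded only to that worker instead of being broadcast to the whole group. The standalone sketch below replays that mapping with the same layout as the updated unit test (two WorkerDict placement groups, dp_size=4, tp_size=2); the helper function name is illustrative, not part of the patch:

# Sketch: how init_engine maps named WorkerDict actors to one AsyncSglangServer instance.
dp_size, tp_size = 4, 2
actor_names = [f"test_prefixWorkerDict_{pg}:{local}" for pg in range(2) for local in range(4)]

def workers_for(dp_rank):
    owned, master = [], None
    for name in actor_names:
        pg_index = int(name.split(":")[0].split("_")[-1])
        local_rank = int(name.split(":")[1])
        global_rank = dp_size * pg_index + local_rank
        if global_rank // tp_size == dp_rank:
            owned.append(name)
            if global_rank / tp_size == dp_rank:  # exact division -> TP rank 0 of the group
                master = name
    return owned, master

print(workers_for(0))  # (['test_prefixWorkerDict_0:0', 'test_prefixWorkerDict_0:1'], 'test_prefixWorkerDict_0:0')
print(workers_for(3))  # (['test_prefixWorkerDict_1:2', 'test_prefixWorkerDict_1:3'], 'test_prefixWorkerDict_1:2')

The same arithmetic appears in the Megatron AsyncActorRolloutRefWorker, which derives vllm_dp_rank = RANK // tensor_model_parallel_size and vllm_tp_rank = RANK % tensor_model_parallel_size from the global rank.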
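
The rollout's new chat_completion method (replacing the chat_completion branch of execute_method) runs only on TP rank 0, wraps the incoming messages in an AsyncRolloutRequest, and assembles an OpenAI-style chat.completion object from the engine output. The request/response keys it relies on are sketched below; the concrete values are illustrative, not taken from the patch:

# Illustrative request/response shapes for the async chat_completion path (values are made up).
request = {
    "model": "sglang_model",  # optional; this is also the fallback used in the response
    "messages": [{"role": "user", "content": "What is 1 + 1?"}],
    # sampling parameters are carried in the request body and passed through to the engine call
    "temperature": 1.0,
    "top_p": 1.0,
    "n": 1,
}

response = {
    "id": "chatcmpl-<engine request id>",
    "object": "chat.completion",
    "created": 1700000000,  # int(time.time()) when the response is built
    "model": "sglang_model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "2"},
            "finish_reason": "stop",  # taken from meta_info["finish_reason"]["type"]
        }
    ],
}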
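
Finally, the sharding managers gain async wake_up()/sleep() counterparts to __enter__/__exit__ because the engine-facing calls (resume/release memory occupation, weight updates, cache flush) are now coroutines: the synchronous batch path bridges into them with loop.run_until_complete, while the async-server path awaits them directly, since an async Ray actor cannot call run_until_complete. A minimal sketch of that dual entry-point pattern, assuming no event loop is already running in the calling thread on the sync path (the coroutine body is a stand-in for the real weight sync):

import asyncio

class ShardingManagerSketch:
    async def update_weights(self):
        # stand-in for: await engine.resume_memory_occupation(); await engine.update_weights_from_tensor(...)
        await asyncio.sleep(0)

    # sync path: used by the batch (SPMD) rollout via the context-manager protocol
    def __enter__(self):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.update_weights())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass

    # async path: used by the async rollout server via the worker's wake_up()/sleep() coroutines
    async def wake_up(self):
        await self.update_weights()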