Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
714a32f
feat: implement vLLM co-located training-inference rollout with proce…
jianjunzhong Nov 24, 2025
06959df
add vLLMWorkerProc
jianjunzhong Nov 28, 2025
d1cd582
update
jianjunzhong Dec 2, 2025
b1e954c
update
jianjunzhong Dec 2, 2025
fab21aa
update
jianjunzhong Dec 2, 2025
d5c11be
fix and add test cases
jianjunzhong Dec 2, 2025
c994d67
update
jianjunzhong Dec 2, 2025
a5de2d5
update
jianjunzhong Dec 2, 2025
e956ca0
update
jianjunzhong Dec 2, 2025
4680742
update
jianjunzhong Dec 2, 2025
8a0b4ef
update
jianjunzhong Dec 2, 2025
d2b7c8b
update
jianjunzhong Dec 2, 2025
6b861a5
add 'non_block' arg for _execute_method()
jianjunzhong Dec 3, 2025
685fd2a
support async execute method in vLLMRollout
jianjunzhong Dec 3, 2025
686891d
add lock for _do_execute()
jianjunzhong Dec 3, 2025
8635731
fix and remove redundant codes
jianjunzhong Dec 3, 2025
38971b6
rename vLLMAsyncRollout to ServerAdapter and update class description
jianjunzhong Dec 6, 2025
b684a86
update
jianjunzhong Dec 7, 2025
ca088a2
update
jianjunzhong Dec 7, 2025
87b9843
remove VERL_VLLM_MULTIPROC_RANK_OFFSET, use CUDA_VISIBLE_DEVICES
jianjunzhong Dec 8, 2025
638339e
fix cuda invalid device ordinal
jianjunzhong Dec 11, 2025
6e2e0aa
remove unnecessary docstring
jianjunzhong Dec 11, 2025
64c9bf8
remove vLLMMultiprocExecutor
jianjunzhong Dec 15, 2025
7c23dad
update
jianjunzhong Dec 15, 2025
9ca2e2c
update
jianjunzhong Dec 15, 2025
5b03b23
update
jianjunzhong Dec 15, 2025
68258b5
remove useless codes
jianjunzhong Dec 15, 2025
dfe4e41
remove useless codes
jianjunzhong Dec 16, 2025
7510e73
merge with main
jianjunzhong Dec 16, 2025
fcf779d
update
jianjunzhong Dec 16, 2025
afa9da9
fix pre-commit
jianjunzhong Dec 16, 2025
d427468
fix ci
jianjunzhong Dec 23, 2025
eb6fb52
update weights using buffer
jianjunzhong Dec 24, 2025
8a7c137
Merge branch 'main' into refactor/vllm_sep_proc
jianjunzhong Dec 24, 2025
4207bdf
fix
jianjunzhong Dec 24, 2025
32dbfad
add ascend npu support
jianjunzhong Jan 3, 2026
13444c8
update weight with bucket
wuxibin89 Jan 13, 2026
38c59bf
Merge branch 'main' into refactor/vllm_sep_proc
wuxibin89 Jan 13, 2026
ea31d56
fix vllm
wuxibin89 Jan 13, 2026
b96cc65
fix ci
wuxibin89 Jan 13, 2026
df617af
fix vllm
wuxibin89 Jan 13, 2026
1792dc4
Merge branch 'main' into refactor/vllm_sep_proc
wuxibin89 Jan 14, 2026
09fb51a
fix sglang import
wuxibin89 Jan 14, 2026
dd851fa
fix ascend env
wuxibin89 Jan 14, 2026
6d3d08d
fix ci
wuxibin89 Jan 14, 2026
dc12216
fix ci
wuxibin89 Jan 14, 2026
d87272c
add env vars for multi procs in a single npu, and fix assertion error…
jianjunzhong Jan 16, 2026
ff0c7de
fix NPU env
wuxibin89 Jan 17, 2026
ea0b051
fix cuda ipc memory leak and free transformer_engine wgrad cache
wuxibin89 Jan 17, 2026
37e3eaa
Free Megatron-LM's global memory buffer
wuxibin89 Jan 18, 2026
18e9f1f
fix megatron zero_grad
wuxibin89 Jan 19, 2026
b89e83e
revert megatron_utils.py
wuxibin89 Jan 20, 2026
5c39ca3
revert get_global_memory_buffer() clear
wuxibin89 Jan 20, 2026
cc73dab
VLLM_DISABLE_COMPILE_CACHE=1
wuxibin89 Jan 20, 2026
d33cbb0
set update_weights_bucket_megabytes=2048
wuxibin89 Jan 20, 2026
945f7d5
recover one-step workflow
wuxibin89 Jan 20, 2026
e7a9c17
Merge branch 'main' into refactor/vllm_sep_proc
wuxibin89 Jan 20, 2026
39c36cf
non blocking copy
wuxibin89 Jan 20, 2026
f263e53
fix ci
wuxibin89 Jan 21, 2026
2808f87
Merge branch 'main' into refactor/vllm_sep_proc
wuxibin89 Jan 21, 2026
c7d11fb
Merge branch 'main' into refactor/vllm_sep_proc
wuxibin89 Jan 21, 2026
2bfe6ea
Merge branch 'main' into refactor/vllm_sep_proc
wuxibin89 Jan 23, 2026
b4e6ae3
temporary disable async ci
wuxibin89 Jan 23, 2026
0b19e64
set vllm cudagraph_mode=FULL_DECODE_ONLY
wuxibin89 Jan 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .github/workflows/e2e_ascend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,9 @@ jobs:
- name: Preprocess gsm8k dataset
run: |
python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k
- name: Running the E2E test with one_step_off_policy algorithm on ASCEND NPU (FSDP2)
run: |
ray stop --force
bash tests/special_npu/run_one_step_off_policy.sh
rm -rf $HOME/ckpts
# TODO(wuxibin): temporarily disabled until we refactor with checkpoint engine
# - name: Running the E2E test with one_step_off_policy algorithm on ASCEND NPU (FSDP2)
# run: |
# ray stop --force
# bash tests/special_npu/run_one_step_off_policy.sh
# rm -rf $HOME/ckpts
8 changes: 4 additions & 4 deletions .github/workflows/npu_unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# - `special_sanity`: a suite of quick sanity tests
# - `special_standalone`: a set of tests that are designed to run in dedicated environments

# Accelerators for tests
# Accelerators for tests
# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`.
# - Test scripts with the `on_cpu.py` name suffix are run on CPU resources in a Linux environment.

Expand Down Expand Up @@ -67,7 +67,7 @@ concurrency:
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Declare permissions: read-only access to repository contents.
permissions:
permissions:
contents: read

jobs:
Expand Down Expand Up @@ -109,7 +109,7 @@ jobs:
- name: Run all NPU unit tests
run: |
export PYTHONPATH=$PYTHONPATH:/Megatron-LM
pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob="*on_cpu.py" --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob="tests/special*" --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" --ignore-glob="*test_rvdz*" --ignore-glob="*test_ray_collectives*" --ignore-glob="*test_nvtx_profile*" --ignore-glob="*test_nccl*" --ignore-glob="*test_nixl*" tests/
pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob="*on_cpu.py" --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob="tests/special*" --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" --ignore-glob="*test_rvdz*" --ignore-glob="*test_ray_collectives*" --ignore-glob="*test_nvtx_profile*" --ignore-glob="tests/checkpoint_engine" tests/
- name: Testing FSDP2 actor functionality
run: |
torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/actor/test_special_dp_actor.py
Expand All @@ -118,4 +118,4 @@ jobs:
torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/critic/test_special_dp_critic.py
- name: Running NPU profiling unit tests
run: |
pytest -s -x tests/utils/test_special_mstx_profile.py
pytest -s -x tests/utils/test_special_mstx_profile.py
3 changes: 1 addition & 2 deletions examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,5 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
trainer.total_epochs=15 \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@
trainer.total_epochs=15 $@

Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,5 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
trainer.total_epochs=15 \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@
trainer.total_epochs=15 $@

Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
trainer.total_epochs=15 \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \
actor_rollout_ref.rollout.trace.token2text=False \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.multi_turn.enable=true \
Expand Down
1 change: 0 additions & 1 deletion examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.multi_stage_wake_up=True \
actor_rollout_ref.rollout.multi_turn.enable=True \
Expand Down
1 change: 1 addition & 0 deletions tests/experimental/agent_loop/test_standalone_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ async def test_standalone_rollout(init_config, tp_size):
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
"VLLM_USE_V1": "1",
"NCCL_P2P_DISABLE": "1", # disable p2p in L20
}
}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/special_e2e/ppo_trainer/run_function_reward.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ROLLOUT_MODE="async"
RETURN_RAW_CHAT="True"
SKIP_TOKENIZER_INIT="True"

GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7}
ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False}
ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}
REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True}
Expand Down
1 change: 0 additions & 1 deletion tests/special_e2e/run_ppo_trainer_megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \
++actor_rollout_ref.rollout.quantization=${ROLLOUT_QUANTIZATION} \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/config/legacy_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ actor_rollout_ref:
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
max_num_seqs: 256
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/config/legacy_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ actor_rollout_ref:
max_model_len: null

# max number of sequences batched together (a count, not a sequence length)
max_num_seqs: 1024
max_num_seqs: 256

# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size.
log_prob_micro_batch_size: null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import requests
from sglang_router.launch_server import RouterArgs, launch_router

from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
Expand Down
2 changes: 1 addition & 1 deletion verl/experimental/reward_loop/router/naive_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
Expand Down
4 changes: 2 additions & 2 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ actor_rollout_ref:
pipeline_model_parallel_size: 1
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
max_num_seqs: 256
enable_chunked_prefill: true
enable_prefix_caching: true
logprobs_mode: processed_logprobs
Expand Down Expand Up @@ -268,7 +268,7 @@ actor_rollout_ref:
_target_: verl.workers.config.CustomAsyncServerConfig
path: null
name: null
update_weights_bucket_megabytes: 512
update_weights_bucket_megabytes: 2048
trace:
_target_: verl.workers.config.TraceConfig
backend: null
Expand Down
4 changes: 2 additions & 2 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ actor_rollout_ref:
pipeline_model_parallel_size: 1
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
max_num_seqs: 256
enable_chunked_prefill: true
enable_prefix_caching: true
logprobs_mode: processed_logprobs
Expand Down Expand Up @@ -259,7 +259,7 @@ actor_rollout_ref:
_target_: verl.workers.config.CustomAsyncServerConfig
path: null
name: null
update_weights_bucket_megabytes: 512
update_weights_bucket_megabytes: 2048
trace:
_target_: verl.workers.config.TraceConfig
backend: null
Expand Down
4 changes: 2 additions & 2 deletions verl/trainer/config/rollout/rollout.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ max_num_batched_tokens: 8192
max_model_len: null

# max number of sequences batched together (a count, not a sequence length)
max_num_seqs: 1024
max_num_seqs: 256

# May give higher throughput when set to True. When activated, please increase max_num_batched_tokens or decrease max_model_len.
enable_chunked_prefill: True
Expand Down Expand Up @@ -253,7 +253,7 @@ agent:
# 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`
# 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`
# when using Tensor Parallelism (TP) >= 8.
update_weights_bucket_megabytes: 512
update_weights_bucket_megabytes: 2048

# trace rollout data
trace:
Expand Down
7 changes: 7 additions & 0 deletions verl/trainer/constants_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@
# https://docs.vllm.ai/en/latest/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues
# https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445
"NCCL_CUMEM_ENABLE": "0",
# TODO: disable compile cache due to cache corruption issue
# https://github.com/vllm-project/vllm/issues/31199
"VLLM_DISABLE_COMPILE_CACHE": "1",
# Needed when multiple processes are colocated on the same NPU device
# https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0143.html
"HCCL_HOST_SOCKET_PORT_RANGE": "auto",
"HCCL_NPU_SOCKET_PORT_RANGE": "auto",
},
}

Expand Down
2 changes: 2 additions & 0 deletions verl/trainer/runtime_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ excludes: ["/.git/"]
env_vars:
TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
CUDA_DEVICE_MAX_CONNECTIONS: "1"
HCCL_HOST_SOCKET_PORT_RANGE: "auto"
HCCL_NPU_SOCKET_PORT_RANGE: "auto"
8 changes: 8 additions & 0 deletions verl/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def is_torch_npu_available(check_device=True) -> bool:
is_npu_available = is_torch_npu_available()


def get_resource_name() -> str:
    """Return the Ray resource name that matches the detected device type.

    Returns:
        ``"GPU"`` when the module-level ``is_cuda_available`` flag is truthy,
        otherwise ``"NPU"``.
    """
    if is_cuda_available:
        return "GPU"
    return "NPU"


def get_visible_devices_keyword() -> str:
"""Get the environment variable name for visible device selection.

Expand Down
13 changes: 13 additions & 0 deletions verl/utils/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,19 @@ def _iter_opts(opt):
v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
if "exp_avg_sq" in v:
v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)

try:
# Free TransformerEngine's dummy weight gradients cache
# https://github.com/NVIDIA/TransformerEngine/blob/release_v2.10/transformer_engine/pytorch/module/base.py#L64
from transformer_engine.pytorch.module.base import _dummy_wgrads

_dummy_wgrads.clear()
except ImportError:
pass

# Free Megatron-LM's global memory buffer
# get_global_memory_buffer().buffer.clear()

gc.collect()
get_torch_device().empty_cache()

Expand Down
1 change: 1 addition & 0 deletions verl/workers/actor/megatron_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,5 +838,6 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals
RouterReplay.clear_global_router_replay_action()
RouterReplay.clear_global_indices()

self.actor_optimizer.zero_grad()
get_torch_device().empty_cache()
return metrics
1 change: 1 addition & 0 deletions verl/workers/engine/megatron/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
    """Leave the context, zeroing the engine's optimizer gradients first.

    Args:
        exc_type: Exception class passed by the ``with`` machinery; forwarded
            unchanged to the parent ``__exit__``.
        exc_value: Exception instance; forwarded unchanged.
        traceback: Traceback object; forwarded unchanged.
    """
    # Guard: this context manager only works with a MegatronEngine.
    assert isinstance(self.engine, MegatronEngine)
    # Reset gradients before the parent tears the context down.
    # NOTE(review): presumably this releases gradient memory before any
    # offload done by the parent class -- confirm against the base __exit__.
    self.engine.optimizer_zero_grad()
    super().__exit__(exc_type, exc_value, traceback)


Expand Down
27 changes: 0 additions & 27 deletions verl/workers/engine_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from contextlib import nullcontext
from functools import partial
from itertools import chain
from typing import Any, Optional

import torch
from codetiming import Timer
Expand Down Expand Up @@ -486,8 +485,6 @@ def init_model(self):
self.set_dispatch_collect(mesh_name="actor", **self.actor.get_dispatch_collect())

# 3. build rollout engine
# - vllm: vLLMAsyncRollout
# - sglang: ServerAdapter
if "rollout" in self.role:
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)

Expand Down Expand Up @@ -595,27 +592,3 @@ async def wake_up(self):
# important: need to manually set the random states of each tp to be identical.
self.torch_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.gen_random_states)

# ============================ vLLM related ============================

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
def get_zeromq_address(self):
return self.rollout.get_zeromq_address()

# ============================ SGLang related ============================

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def chat_completion(self, json_request):
ret = await self.rollout.chat_completion(json_request)
return ret

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def generate(
self,
prompt_ids: list[int],
sampling_params: dict[str, Any],
request_id: str,
image_data: Optional[list[Any]] = None,
) -> list[int]:
ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data)
return ret
25 changes: 0 additions & 25 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import os
import warnings
from dataclasses import asdict
from typing import Any, Optional

import numpy as np
import psutil
Expand Down Expand Up @@ -2020,27 +2019,3 @@ async def wake_up(self):
async def sleep(self):
await self.trainer_mode()
return True

# ============================ vLLM related ============================

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
def get_zeromq_address(self):
return self.rollout.get_zeromq_address()

# ============================ SGLang related ============================

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def chat_completion(self, json_request):
ret = await self.rollout.chat_completion(json_request)
return ret

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def generate(
self,
prompt_ids: list[int],
sampling_params: dict[str, Any],
request_id: str,
image_data: Optional[list[Any]] = None,
) -> list[int]:
ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data)
return ret
25 changes: 0 additions & 25 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
import os
import time
from typing import Any, Optional

import psutil
import torch
Expand Down Expand Up @@ -989,30 +988,6 @@ async def sleep(self):
await self.trainer_mode()
return True

# ============================ vLLM related ============================

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
def get_zeromq_address(self):
return self.rollout.get_zeromq_address()

# ============================ SGLang related ============================

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def chat_completion(self, json_request):
ret = await self.rollout.chat_completion(json_request)
return ret

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def generate(
self,
prompt_ids: list[int],
sampling_params: dict[str, Any],
request_id: str,
image_data: Optional[list[Any]] = None,
) -> list[int]:
ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data)
return ret


class CriticWorker(MegatronWorker, DistProfilerExtension):
def __init__(self, config: McoreCriticConfig):
Expand Down
2 changes: 1 addition & 1 deletion verl/workers/rollout/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:


_ROLLOUT_REGISTRY = {
("vllm", "async"): "verl.workers.rollout.vllm_rollout.vLLMAsyncRollout",
("vllm", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter",
("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter",
}
Expand Down
Loading