
Commit 333e6d6

[rollout] feat: add SGLang as rollout engine to verl (#490)
#22. WIP, will add more details tomorrow :)

Co-authored-by: zhaochenyang20 <[email protected]>
1 parent 3b18b0e commit 333e6d6

19 files changed, +1177 -25 lines changed

.github/workflows/e2e_sglang_gsm8k.yml

+60
@@ -0,0 +1,60 @@
name: e2e_sglang_gsm8k

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
      - v0.2.x
    paths:
      - "**/*.py"
      - .github/workflows/e2e_sglang_gsm8k.yml
  pull_request:
    branches:
      - main
      - v0.2.x
    paths:
      - "**/*.py"
      - "verl/trainer/config/*.yaml"
      - .github/workflows/e2e_sglang_gsm8k.yml
      - "tests/e2e/*.sh"

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Declare permissions just read content.
permissions:
  contents: read

jobs:
  e2e_sglang_gsm8k:
    runs-on: [self-hosted, l20-1]
    timeout-minutes: 40 # Increase this timeout value as needed
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.3.post3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test,gpu,sglang] --no-deps
      - name: Prepare gsm8k dataset
        run: |
          ray stop --force
          python3 examples/data_preprocess/gsm8k.py
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using function rm and save ckpt
        run: |
          ray stop --force
          bash tests/e2e/run_qwen_gsm8k_function_rm.sh sglang

.gitignore

+3-1
@@ -93,6 +93,7 @@ celerybeat-schedule
# virtualenv
venv/
+.venv/
ENV/

# Spyder project settings
@@ -122,4 +123,5 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt

# local logs
logs
-log
+log
+outputs

docs/start/install.rst

+21-1
@@ -10,7 +10,7 @@ Requirements
verl supports various backends. Currently, the following configurations are available:

- **FSDP** and **Megatron-LM** (optional) for training.
-- **vLLM** and **TGI** for rollout generation, **SGLang** support coming soon.
+- **SGLang**, **vLLM** and **TGI** for rollout generation.

Training backends
------------------
@@ -19,6 +19,25 @@ We recommend using **FSDP** backend to investigate, research and prototype diffe

For users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support Megatron-LM v0.4 [1]_. The guide for using Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`.

+Install verl-SGLang from scratch
+-------------------------------------
+
+**SGLang has largely supported the research and inference workloads at xAI. For the verl-SGLang installation, ignore the version conflicts that pip reports against vllm. SGLang also provides a native API for RLHF, so you do not need to patch a single line of code.**
+
+The following steps are a quick installation guide for verl-SGLang.
+
+.. code:: bash
+
+   # Create a virtual environment and use uv for quick installation
+   python3 -m venv ~/.python/verl-sglang && source ~/.python/verl-sglang/bin/activate
+   python3 -m pip install --upgrade pip && python3 -m pip install --upgrade uv
+
+   # Install verl-SGLang
+   git clone https://github.com/volcengine/verl verl-sglang && cd verl-sglang
+   python3 -m uv pip install .
+
+   # Install the latest stable version of sglang with verl support; currently that is 0.4.3.post3
+   # For SGLang installation, you can also refer to https://docs.sglang.ai/start/install.html
+   python3 -m uv pip install "sglang[all]==0.4.3.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python

Install from docker image
-------------------------
@@ -73,6 +92,7 @@ Image and tag: ``whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-
git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM
export PYTHONPATH=$PYTHONPATH:$(pwd)/Megatron-LM

+
Install from custom environment
---------------------------------
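Once verl-SGLang is installed, the rollout backend is selected through the trainer configuration rather than code changes. A minimal sketch based on the e2e script modified in this commit; the data, model, and most tuning overrides are omitted here and would be supplied as in `tests/e2e/run_qwen_gsm8k_function_rm.sh`:

```bash
# Sketch: pick SGLang as the rollout engine via a command-line config override
# (required data/model overrides omitted; see tests/e2e/run_qwen_gsm8k_function_rm.sh for a full run)
python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1
```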

pyproject.toml

+1
@@ -57,6 +57,7 @@ test = [
]
prime = ["pyext"]
gpu = ["liger-kernel", "flash-attn"]
+sglang = ["sglang[all]==0.4.3.post3"]

# URLs
[project.urls]
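The new optional dependency group can be pulled in with the usual extras syntax; a sketch assuming an editable checkout of verl (the CI job above combines it with the `test` and `gpu` extras):

```bash
# Install verl together with the sglang extra defined in pyproject.toml (editable checkout assumed)
pip3 install -e ".[sglang]"
```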

requirements.txt

+1-1
@@ -17,5 +17,5 @@ ray[default]
tensordict<0.6
torchdata
transformers
-vllm<=0.6.3
+# vllm==0.6.3.post1
wandb

tests/e2e/run_qwen_gsm8k_function_rm.sh

+4-4
@@ -1,5 +1,5 @@
set -x
-
+ENGINE=${1:-vllm}
export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
@@ -17,7 +17,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
-    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.name=$ENGINE \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
@@ -36,5 +36,5 @@ python3 -m verl.trainer.main_ppo \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=1 \
-    trainer.default_local_dir=$HOME/ckpt/ \
-    trainer.total_training_steps=1 $@
+    trainer.default_local_dir=$HOME/$ENGINE/ckpt/ \
+    trainer.total_training_steps=1
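With the engine now taken from the first positional argument, the same script drives either backend:

```bash
# Run the GSM8K function-RM e2e test with SGLang, or fall back to the default engine
bash tests/e2e/run_qwen_gsm8k_function_rm.sh sglang
bash tests/e2e/run_qwen_gsm8k_function_rm.sh          # ENGINE defaults to vllm
```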

tests/rollout/test_sglang_spmd.py

+210
@@ -0,0 +1,210 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GenerationConfig

from verl.utils.torch_functional import pad_sequence_to_length


def levenshtein(s1, s2):
    m, n = len(s1), len(s2)
    # Initialize matrix of zeros
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    # Initialize first column and first row of the matrix
    for i in range(m + 1):
        dp[i][0] = i  # Deletion from s1 to empty string
    for j in range(n + 1):
        dp[0][j] = j  # Insertion to s1 from empty string
    # Compute the Levenshtein distance matrix
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if s1[i - 1] == s2[j - 1] else 1  # No cost if characters match
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # Deletion
                dp[i][j - 1] + 1,  # Insertion
                dp[i - 1][j - 1] + cost  # Substitution
            )
    return dp[m][n]


def are_lists_similar(a, b):
    if len(a) != len(b):
        print("The lists are of different lengths.")
        return False

    total_length = 0
    total_diff = 0

    for s1, s2 in zip(a, b):
        max_len = max(len(s1), len(s2))
        total_length += max_len
        diff = levenshtein(s1, s2)
        total_diff += diff
        print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n")

    percentage_difference = (total_diff / total_length) * 100
    print(f"Total difference: {percentage_difference:.2f}%")

    return percentage_difference <= 10


def initialize_global_process_group(timeout_second=36000):
    from datetime import timedelta

    import torch.distributed

    # NOTE MODIFIED should provide backend=None to have nccl+gloo
    # torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
    torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size


def test_sglang_spmd():
    assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.'
    initialize_global_process_group()
    # fill rollout config
    max_prompt_length = 16
    max_response_length = 16

    # Initialize model and token
    local_cache_path = '~/.cache/verl/rlhf'
    local_cache_path = os.path.expanduser(local_cache_path)
    hdfs_path = 'Qwen/Qwen2-7B-Instruct'
    from verl.utils.fs import copy_to_local
    local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left')

    preencode_prompts = [
        "Who won the Champions League in 2019?",
        "The founder of Apple is",
        "What's your name",
    ]
    tokenizer.pad_token = tokenizer.eos_token
    prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
    input_ids = prompts['input_ids']
    attention_mask = prompts['attention_mask']

    input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
    attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)

    actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
    actor_model.to(torch.bfloat16)

    sampling_params = dict(n=1,
                           temperature=0,
                           top_p=1,
                           top_k=-1,
                           max_new_tokens=max_response_length,
                           presence_penalty=0.0,
                           frequency_penalty=0.0,
                           repetition_penalty=1.0,
                           skip_special_tokens=True,
                           spaces_between_special_tokens=True,
                           ignore_eos=False)

    tensor_parallel_size = 4
    device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"])
    inference_device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)

    for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
        if k in os.environ:
            del os.environ[k]
    print('building sglang rollout engine')
    llm = VerlEngine(model_path=local_model_path,
                     dtype="bfloat16",
                     mem_fraction_static=0.5,
                     device_mesh_cpu=inference_device_mesh_cpu["tp"],
                     base_gpu_id=0,
                     gpu_id_step=1)

    llm.release_memory_occupation()
    print("start generation")
    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    batch_size = input_ids.size(0)

    generation_config = GenerationConfig(do_sample=False)
    actor_model.cuda()
    output = actor_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_response_length,
        # max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        generation_config=generation_config,
        # renormalize_logits=True,
        output_scores=False,  # this is potentially very large
        return_dict_in_generate=True,
        use_cache=False)  # may OOM when use_cache = True
    seq = output.sequences
    response = seq[:, max_prompt_length:]

    hf_response_tokens = tokenizer.batch_decode(response)
    print(f"hf response: {hf_response_tokens}")
    print(f"{sampling_params=}")
    idx_list = []
    batch_size = input_ids.shape[0]

    pad_token_id = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)
    for i in range(batch_size):
        idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))

    outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params)
    sglang_response_tokens = []

    for output in outputs:
        print(f"{output=}")
        generated_text = output["text"]
        sglang_response_tokens.append(generated_text)

    print(f"sglang response: {sglang_response_tokens}")
    assert are_lists_similar(hf_response_tokens, sglang_response_tokens), \
        f"Strings differ more than 10%:\n"
    print("Check Pass")


def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
    # remove the left padding in the prompt token_id
    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids
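The test reads `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` from the environment and builds a (dp=1, tp=4, pp=1) device mesh, so it expects to be started by a distributed launcher on a multi-GPU node. A hypothetical invocation; the launcher choice and GPU count below are assumptions, not part of this commit:

```bash
# Hypothetical launch: one node with 4 GPUs to match tensor_parallel_size=4 in the test
torchrun --standalone --nnodes=1 --nproc_per_node=4 -m pytest -q -s tests/rollout/test_sglang_spmd.py
```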

verl/single_controller/ray/base.py

+1
@@ -465,6 +465,7 @@ def __init__(self):
for key, user_defined_cls in cls_dict.items():
    user_defined_cls = _unwrap_ray_remote(user_defined_cls)
    # directly instantiate the class without remote
+    # in the worker class (e.g. <verl.single_controller.base.worker.Worker>), __init__ returns immediately when DISABLE_WORKER_INIT == 1
    with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}):
        self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()),
                                                 **init_args_dict[key].get('kwargs', {}))
