Merged
Changes from 22 commits
27 commits
dd44871
[Generator] Add initial support for non-remote SGLang engine
CharlieFRuan Jul 8, 2025
d8511cc
fix lint
CharlieFRuan Jul 8, 2025
aa23222
support CUDA IPC weight sync
CharlieFRuan Jul 9, 2025
a591da4
update bash script
CharlieFRuan Jul 9, 2025
b5e8ecd
fix gibberish output in eval
CharlieFRuan Jul 10, 2025
b38f6d2
add unit test for local sglang engine
CharlieFRuan Jul 11, 2025
3487d21
Add more tests
CharlieFRuan Jul 11, 2025
d24e601
fix lint
CharlieFRuan Jul 11, 2025
ef3303b
trivial
CharlieFRuan Jul 11, 2025
a6b406e
trivial
CharlieFRuan Jul 11, 2025
a9612fc
trivial
CharlieFRuan Jul 11, 2025
aae58a2
fix CI
CharlieFRuan Jul 12, 2025
7cb62dc
fix CI
CharlieFRuan Jul 12, 2025
2dc1215
fix CI
CharlieFRuan Jul 12, 2025
94d8608
fix CI
CharlieFRuan Jul 12, 2025
6009589
Address comments
CharlieFRuan Jul 14, 2025
55d58c0
remove sglang test folder
CharlieFRuan Jul 14, 2025
673132f
remove cpu ci test
CharlieFRuan Jul 14, 2025
cac9eb7
revert uv.lock change
CharlieFRuan Jul 14, 2025
aa90146
fix vllm gpu test
CharlieFRuan Jul 14, 2025
172eea7
fix dtype error
CharlieFRuan Jul 14, 2025
b0ad3a2
Pass in CPU tensor to MultiprocessingSerializer.serialize() instead
CharlieFRuan Jul 23, 2025
977f6f4
Merge branch 'main' into pr-0707-sglang-non-remote
CharlieFRuan Aug 8, 2025
c2e2209
fix gpu tests
CharlieFRuan Aug 11, 2025
f17619c
Merge branch 'main' into pr-0707-sglang-non-remote
CharlieFRuan Aug 11, 2025
0b5978e
Merge branch 'main' into pr-0707-sglang-non-remote
CharlieFRuan Aug 12, 2025
26ac831
Fix SGLang with NCCL_CUMEM_ENABLE
CharlieFRuan Aug 12, 2025
7 changes: 5 additions & 2 deletions skyrl-train/examples/gsm8k/run_gsm8k.sh
@@ -12,7 +12,10 @@ DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb" # change to "console" to print to stdout

uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
INFERENCE_BACKEND="vllm"
# INFERENCE_BACKEND="sglang"

uv run --isolated --extra $INFERENCE_BACKEND -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
@@ -37,7 +40,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.backend=vllm \
generator.backend=$INFERENCE_BACKEND \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
4 changes: 2 additions & 2 deletions skyrl-train/pyproject.toml
@@ -63,7 +63,7 @@ dev = [
"black==24.10.0",
"pytest>=6.2.5",
"pytest-asyncio",
"pre-commit"
"pre-commit",
]
docs = [
"sphinx>=7.0.0",
@@ -82,7 +82,7 @@ vllm = [
"torchvision"
]
sglang = [
"sglang[srt,openai,torch_memory_saver]==0.4.8.post1",
"sglang[srt,openai,torch_memory_saver]==0.4.8.post1", # 0.4.9.post1 causes non-colocate weight broadcast to hang
# The version is pinned to 0.2.5 because sglang requires this
# NOTE (sumanthrh): This can be made a common dependency, but then different inference engines can pin different compatible flashinfer versions and it might quickly break.
"flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl",
3 changes: 2 additions & 1 deletion skyrl-train/skyrl_train/entrypoints/main_base.py
@@ -51,12 +51,13 @@ def create_ray_wrapped_inference_engines_from_config(cfg: DictConfig, colocate_p
max_model_len=cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length,
shared_pg=colocate_pg,
gpu_memory_utilization=cfg.generator.gpu_memory_utilization,
vllm_enable_sleep=cfg.trainer.placement.colocate_all,
inference_engine_enable_sleep=cfg.trainer.placement.colocate_all,
async_engine=cfg.generator.async_engine,
max_num_batched_tokens=cfg.generator.max_num_batched_tokens,
max_num_seqs=cfg.generator.max_num_seqs,
sampling_params=get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params),
tokenizer=tokenizer,
backend=cfg.generator.backend,
)


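As a hedged sketch of how these settings tie together (config keys taken from run_gsm8k.sh and the diffs above; the real schema lives in the repo's config files, and the OmegaConf usage here is only illustrative), backend selection is a single generator config field that create_ray_wrapped_inference_engines_from_config forwards into the engine-creation function shown in the next diff:

```python
# Hedged sketch: config keys are taken from run_gsm8k.sh and the diffs above; the
# actual schema is defined by the repo's config files, not reproduced here.
from omegaconf import DictConfig, OmegaConf

cfg: DictConfig = OmegaConf.create(
    {
        "generator": {
            "backend": "sglang",        # or "vllm"; forwarded as backend=cfg.generator.backend
            "run_engines_locally": True,
            "weight_sync_backend": "nccl",
            "async_engine": True,       # only selects between vLLM actor classes; SGLang has one engine type
        },
        # colocate_all is what main_base maps onto inference_engine_enable_sleep
        "trainer": {"placement": {"colocate_all": True}},
    }
)
```

Keeping the switch at the config level means the run scripts only need to change the uv `--extra` and `generator.backend`, as the run_gsm8k.sh diff above shows.
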
@@ -25,7 +25,7 @@ def tp_size(self):
return ray.get(self.inference_engine_actor.tp_size.remote())

async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
return await self.inference_engine_actor.generate.remote(input_batch)
return await self.inference_engine_actor.generate.remote(input_batch=input_batch)

async def wake_up(self, *args: Any, **kwargs: Any):
return await self.inference_engine_actor.wake_up.remote(*args, **kwargs)
@@ -62,21 +62,29 @@ def create_ray_wrapped_inference_engines(
max_model_len: int,
shared_pg=None,
gpu_memory_utilization=None,
vllm_enable_sleep=False,
inference_engine_enable_sleep=False,
async_engine=False,
max_num_batched_tokens=8192,
max_num_seqs=1024,
sampling_params: Optional[Dict[str, Any]] = None,
tokenizer=None,
backend="vllm",
) -> List[InferenceEngineInterface]:
"""
Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface instances.
"""
import vllm
from skyrl_train.inference_engines.vllm.vllm_engine import VLLMRayActor, AsyncVLLMRayActor
from skyrl_train.utils import ray_noset_visible_devices, get_all_env_variables, get_ray_pg_ready_with_timeout

assert vllm.__version__ >= "0.8.3", "SkyTrainer only supports vLLM >= 0.8.3"
if backend == "vllm":
import vllm
from skyrl_train.inference_engines.vllm.vllm_engine import VLLMRayActor, AsyncVLLMRayActor

assert vllm.__version__ >= "0.8.3", "SkyTrainer only supports vLLM >= 0.8.3"
elif backend == "sglang":
# We import SGLang later to avoid importing vllm. See `get_sglang_engine` for more.
pass
else:
raise ValueError(f"Unsupported backend: {backend}")
inference_engine_actors = []
noset_visible_devices = ray_noset_visible_devices(ray.get(get_all_env_variables.remote()))
# NOTE: we use the ray backend for tensor parallel size > 1 to explicitly manage resource allocation
@@ -106,42 +114,92 @@ def create_ray_wrapped_inference_engines(
placement_group_bundle_index=i * tensor_parallel_size,
)

if async_engine:
actor_class = AsyncVLLMRayActor
else:
actor_class = VLLMRayActor

vllm_engine = actor_class.options(
num_cpus=num_gpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
).remote(
model=pretrain,
enforce_eager=enforce_eager,
worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap",
tensor_parallel_size=tensor_parallel_size,
seed=seed + i,
distributed_executor_backend=distributed_executor_backend,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
dtype=model_dtype,
trust_remote_code=True,
vllm_v1_disable_multiproc=vllm_v1_disable_multiproc,
gpu_memory_utilization=gpu_memory_utilization,
bundle_indices=bundle_indices,
num_gpus=0.2 if use_hybrid_engine else 1,
enable_sleep_mode=vllm_enable_sleep,
noset_visible_devices=noset_visible_devices,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
sampling_params=sampling_params,
tokenizer=tokenizer,
)
inference_engine_actors.append(vllm_engine)
if backend == "vllm":
if async_engine:
actor_class = AsyncVLLMRayActor
else:
actor_class = VLLMRayActor

engine = actor_class.options(
num_cpus=num_gpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
).remote(
model=pretrain,
enforce_eager=enforce_eager,
worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap",
tensor_parallel_size=tensor_parallel_size,
seed=seed + i,
distributed_executor_backend=distributed_executor_backend,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
dtype=model_dtype,
trust_remote_code=True,
vllm_v1_disable_multiproc=vllm_v1_disable_multiproc,
gpu_memory_utilization=gpu_memory_utilization,
bundle_indices=bundle_indices,
num_gpus=0.2 if use_hybrid_engine else 1,
enable_sleep_mode=inference_engine_enable_sleep,
noset_visible_devices=noset_visible_devices,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
sampling_params=sampling_params,
tokenizer=tokenizer,
)
elif backend == "sglang":
# NOTE: there is no async / sync engine distinction in SGLang

# NOTE(Charlie): We need `torch.cuda.is_available()` to be True to import SGLang. Otherwise, it requires
# importing vllm. See https://github.com/sgl-project/sglang/blob/v0.4.8.post1/python/sglang/srt/layers/quantization/utils.py#L11-L17
# Similar comment: https://github.com/volcengine/verl/blob/9cc307767b0c787e8f5ef581dac929f7bde044ef/verl/workers/fsdp_workers.py#L520-L527
@ray.remote
def get_sglang_engine():
# A workaround to avoid importing vllm is to give this task a GPU.
import os

before_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from skyrl_train.inference_engines.sglang.sglang_engine import SGLangRayActor

os.environ["CUDA_VISIBLE_DEVICES"] = before_cuda_visible_devices

Comment on lines +162 to +167

Member:
Seems like another issue with CUDA_VISIBLE_DEVICES patching in Ray for num_gpus=0 cc @pcmoritz

pcmoritz (Collaborator), Jul 23, 2025:
We are going to fix it, see ray-project/ray#54868

Member:
Great!

actor_class = SGLangRayActor
engine = actor_class.options(
num_cpus=num_gpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
).remote(
model_path=pretrain,
tp_size=tensor_parallel_size,
mem_fraction_static=gpu_memory_utilization,
random_seed=seed + i,
context_length=max_model_len,
disable_radix_cache=not enable_prefix_caching,
dtype=model_dtype,
trust_remote_code=True,
max_prefill_tokens=max_num_batched_tokens,
max_running_requests=max_num_seqs,
# Borrowed from veRL's SGLang rollout
mm_attention_backend="fa3",
attention_backend="fa3",
enable_memory_saver=inference_engine_enable_sleep,
# Will be popped before instantiating sgl.Engine
distributed_executor_backend=distributed_executor_backend,
noset_visible_devices=noset_visible_devices,
bundle_indices=bundle_indices,
num_gpus=0.2 if use_hybrid_engine else 1,
sampling_params=sampling_params,
tokenizer=tokenizer,
)
return engine

engine = ray.get(get_sglang_engine.remote())

inference_engine_actors.append(engine)

engines = [RayWrappedInferenceEngine(actor_handle) for actor_handle in inference_engine_actors]

if vllm_enable_sleep:
if inference_engine_enable_sleep:
sleep_refs = [engine.inference_engine_actor.sleep.remote() for engine in engines]
ray.get(sleep_refs)

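For readers following the inline thread above: the CUDA_VISIBLE_DEVICES handling inside get_sglang_engine reduces to a save/patch/restore around the SGLang import so that torch.cuda.is_available() is True while the module loads. A minimal, illustrative sketch of that pattern (the helper name is invented; this is not code from the PR):

```python
# Illustration only (not the PR's code): the "temporarily expose a GPU during import"
# pattern used by get_sglang_engine above. It patches CUDA_VISIBLE_DEVICES so that
# torch.cuda.is_available() is True while SGLang is imported, then restores whatever
# value Ray had set (possibly "" for a num_gpus=0 task, per the review thread).
import os
from contextlib import contextmanager


@contextmanager
def temporarily_visible_gpu(device: str = "0"):
    previous = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    os.environ["CUDA_VISIBLE_DEVICES"] = device
    try:
        yield
    finally:
        os.environ["CUDA_VISIBLE_DEVICES"] = previous


# Usage sketch:
# with temporarily_visible_gpu():
#     from skyrl_train.inference_engines.sglang.sglang_engine import SGLangRayActor
```

Per the thread, ray-project/ray#54868 tracks the underlying Ray behavior around CUDA_VISIBLE_DEVICES patching for num_gpus=0 tasks.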