
Conversation

@CharlieFRuan
Collaborator

@CharlieFRuan CharlieFRuan commented Jul 8, 2025

Overview

Before this PR, SGLang could only be used as a rollout backend through a remote server (see sglang_server.py).

This PR implements sglang_engine.py to allow using SGLang locally (e.g., colocated with the policy model).

We bump SGLang to 0.4.8.post1 for now. Bumping to 0.4.9.post1 causes weight sync to hang in the non-colocated case (while still using the local engine) -- i.e. the no_colocate_nccl_fsdp2_sglang test in test_policy_local_engines_e2e.py would fail. 0.4.8.post1 already supports two-stage wake-up: sgl-project/sglang#7099

Currently, we still do not support TP > 1 with the local engines; this is left as a future TODO.

Three quirks

  1. We use a remote task get_sglang_engine() to create SGLangInferenceEngine, since importing SGLang requires a GPU; without one, SGLang tries to import vLLM, which makes dependency management messy.
  2. To support weight sync via CUDA IPC, we need per-TP-worker code. Since SGLang does not support a worker-extension-cls like vLLM does, the only way I found is to use custom_weight_loader. We base64-encode the IPC handles into a tensor and reuse SGLang's update_weights_from_tensor() (see the sketch after this list).
  3. SGLang currently cannot sleep, wake up, and then start generating; an explicit weight sync is required, hence the no_sync parameter change in eval_weights_manager ([Bug][sleep] Create engine, sleep, wake up, generate --> gibberish sgl-project/sglang#7939)
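
For reference, a minimal sketch of the encoding in quirk 2, assuming pickle for the byte serialization (the function name is illustrative, not an actual helper in this PR; the end-marker constant matches the one in the diff below):

import base64
import pickle

import torch

END_MARKER = b"__END_OF_REQUEST__"

def encode_ipc_request_as_tensor(request: dict) -> torch.Tensor:
    # Pack a weight-sync request (including CUDA IPC handles) into a uint8 CPU
    # tensor so it can ride through SGLang's update_weights_from_tensor() path.
    payload = base64.b64encode(pickle.dumps(request)) + END_MARKER
    return torch.frombuffer(bytearray(payload), dtype=torch.uint8)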

Tests

  • Parametrized test_policy_vllm_e2e.py to also run with SGLang, and renamed the test accordingly. This test covers instantiating the engine, sleep, wake up, weight sync, and generation, across different config combinations.
  • Parametrized test_engine_generation.py, which now tests both remote and local SGLang.
  • See the E2E results below as well

Future TODO

  • Support TP > 1 for the non-remote SGLang engines, reaching parity with non-remote vLLM engines

E2E run_gsm8k.sh on 4xH100

Four runs in total: for each of vLLM and SGLang, one non-colocated run (2 TP=1 engines for inference, 2 for training) and one colocated run (4 TP=1 engines for inference, 4 for training).
Performance
[screenshot: performance comparison]

Metrics
[screenshot: training metrics]

"""Update named weights in SGLang engine."""
extras = request.get("extras")
if extras is not None and "ipc_handles" in extras:
# CUDA IPC -- Here we reuse SGLang's update_weights_from_tensor, but actually load the
Collaborator Author


For CUDA IPC weight syncing, an alternative is to use SGLang's Engine.collective_rpc() and patch an update_weights_cuda_ipc() onto sgl's Scheduler (the entity that carries out the RPC), mimicking our vLLM implementation.

Tried it, but I found it hard to patch a method onto SGLang's Scheduler without native support like vLLM's worker_extension_cls, since sgl instantiates the Scheduler in a subprocess and it can easily lose what we patched in the main process. Modifying the source code might work, but the current solution is likely better?
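
A generic illustration of the failure mode (not SGLang code; assumes a spawn-started subprocess): the child re-imports modules, so a method patched onto a class in the parent simply is not there in the child.

import multiprocessing as mp

class Scheduler:  # stand-in for SGLang's Scheduler
    pass

def child_check():
    # Runs in the spawned child: the patch applied inside the __main__ guard
    # below was never executed here, so the attribute is missing.
    print(hasattr(Scheduler, "update_weights_cuda_ipc"))  # prints False

if __name__ == "__main__":
    Scheduler.update_weights_cuda_ipc = lambda self, req: None  # patched in parent only
    p = mp.get_context("spawn").Process(target=child_check)
    p.start()
    p.join()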

Member


I do think this current solution is better than patching. Let's now chat with the SGLang folks to get their input in more detail on enabling per-TP-worker code.

deepspeed = [
    "deepspeed==0.16.5"
]
cpu_ci_test = [
Collaborator Author


This is added for CPU tests like test_models.py, which requires flash attention. We can remove this once vllm and sglang depend on the same flash attention (at which point it moves back into the main dependencies rather than being duplicated in each extra).

@CharlieFRuan CharlieFRuan changed the title [Generator] Add initial support for non-remote SGLang engine [Generator] Support non-remote (e.g. colocated) SGLang engine Jul 12, 2025
@tyler-griggs tyler-griggs self-requested a review July 13, 2025 20:21
Member

@tyler-griggs tyler-griggs left a comment


Adding a first round of comments

]
vllm = [
    "vllm==0.8.5",
    "flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl",
Member


I am going to momentarily ignore the pyproject file in this first pass of reviews because I assume it will be rebased on #73



# TODO(charlie): duplicate of setup_envvars_for_vllm, is it needed?
def setup_envvars_for_sglang(kwargs, bundle_indices):
Member


As far as I know, distributed_executor_backend is specific to vLLM's configuration. Also, it should only matter for TP > 1.

For noset_visible_devices, I have a feeling this would also only matter for TP > 1, so it's hard to test now whether something like this will be needed.

Collaborator Author


I see! I'll leave it here and address it when we support TP > 1, would that be fine? It does not seem to affect the current cases.

Member


sgtm

Member


By the way, not sure if you saw it, but we did have a similar test for sglang under tests/sglang. It seems like your changes here replace it and we should delete the tests/sglang folder. We originally split sglang into its own folder so we could just run uv run --isolated --extra dev --extra vllm pytest tests/gpu and run all tests without having to separately run vllm and sglang. I don't know the right way to structure this long-term, but for now we rarely run all gpu tests like this and more often we manually choose some subset of tests, so I think it's fine.

Collaborator Author


Ah yes, I'll delete tests/sglang. And what you said about being able to run an entire folder with --extra vllm makes sense. We could structure our future tests with marks=pytest.mark.sglang and pytest.mark.vllm, which should keep both groups of tests in the same folder while still letting us run the entire folder.
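
A rough sketch of what that could look like (illustrative test, not code from this PR; the markers would also need to be registered, e.g. under [tool.pytest.ini_options], to avoid warnings):

import pytest

@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("vllm", marks=pytest.mark.vllm),
        pytest.param("sglang", marks=pytest.mark.sglang),
    ],
)
def test_generation_backends(backend):
    # Hypothetical body; the real tests instantiate the engine and generate.
    assert backend in ("vllm", "sglang")

Something like pytest -m sglang tests/gpu (or -m vllm) would then select just one group while both live in the same folder.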

@CharlieFRuan CharlieFRuan force-pushed the pr-0707-sglang-non-remote branch from 83844a7 to 55d58c0 Compare July 14, 2025 21:47
@CharlieFRuan
Collaborator Author

@tyler-griggs addressed comments, rebased to main, and tested with the GPU tests and run_gsm8k.sh. Ready for another round of review :)

Comment on lines +160 to +165
# Importing sglang requires a visible GPU, but Ray patches CUDA_VISIBLE_DEVICES for
# num_gpus=0 tasks (see discussion below), so temporarily expose GPU 0 for the import
# and then restore the original value.
before_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from skyrl_train.inference_engines.sglang.sglang_engine import SGLangRayActor

os.environ["CUDA_VISIBLE_DEVICES"] = before_cuda_visible_devices

Member


Seems like another issue with CUDA_VISIBLE_DEVICES patching in Ray for num_gpus=0 cc @pcmoritz

Collaborator

@pcmoritz pcmoritz Jul 23, 2025


We are going to fix it, see ray-project/ray#54868

Member


Great!

request_tensor = [("ipc_request", request_tensor)]
obj = UpdateWeightsFromTensorReqInput(
    serialized_named_tensors=[
        MultiprocessingSerializer.serialize(request_tensor) for _ in range(self._tp_size)
Member


QQ: I don't know much about weight sync in sglang. Does MultiprocessingSerializer.serialize take a tensor on GPU and serialize it via pickle in sglang? If so, I'm wondering why the base64 encoding + .cuda() conversion was needed?

Collaborator Author

@CharlieFRuan CharlieFRuan Jul 23, 2025


  • serialize() has to take a list of (str, torch.Tensor) (i.e. named tensors) so that it matches what the SGLang codepath expects when it deserializes them
  • We want to pass our request (which includes the IPC handle info) down that codepath so our customized update_weight_cuda_ipc() can handle it in a per-TP-worker fashion
  • However, our request is not a tensor, so we base64-encode it and store it in a torch.uint8 tensor

Not sure if I answered your question...

Member


Okay got it.

I am wondering why the tensor needs to be on GPU though? I am not too familiar with sglang internals, but it looks like the Engine class runs the tokenizer manager while the TP workers are in a different process. Does MultiprocessingSerializer.serialize expect a GPU tensor?

Collaborator Author


Oh yes, you are right, it should be a CPU tensor. Changed and tested. Thanks for the catch!
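
For completeness, a sketch of the receiving side, assuming the same pickle + base64 packing as in the sketch in the PR description (the function name is illustrative; the real logic lives in the custom weight loader and also performs the actual CUDA IPC load):

import base64
import pickle

import torch

END_MARKER = b"__END_OF_REQUEST__"  # same marker the sender appends

def decode_ipc_request(tensor: torch.Tensor) -> dict:
    # The payload tensor arrives on CPU; strip the end marker, then reverse the
    # base64 + pickle packing to recover the request containing the IPC handles.
    tensor_bytes = tensor.numpy().tobytes()
    end_index = tensor_bytes.find(END_MARKER)
    if end_index == -1:
        raise ValueError("End marker not found in tensor data")
    return pickle.loads(base64.b64decode(tensor_bytes[:end_index]))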

@SumanthRH
Member

cc @tyler-griggs Let's try to get this PR in today

@SumanthRH
Member

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for using SGLang as a non-remote (colocated) inference engine. The changes include a new sglang_engine.py implementation with workarounds for SGLang's limitations, updates to the engine creation logic to support both vLLM and SGLang backends, and extensive refactoring of tests to be parameterized and cover both backends. My review focuses on improving robustness, maintainability, and cleaning up debugging artifacts. Key suggestions include ensuring correct GPU scheduling for Ray tasks, making data serialization more robust, and replacing debug print statements with structured logging.

Comment on lines +88 to +89
# TODO(charlie): duplicate of setup_envvars_for_vllm, is it needed?
def setup_envvars_for_sglang(kwargs, bundle_indices):
Contributor


medium

The TODO comment is valid. The function setup_envvars_for_sglang is very similar to setup_envvars_for_vllm in skyrl_train/inference_engines/vllm/vllm_engine.py. To improve maintainability and reduce code duplication, these two functions should be refactored into a single, shared utility function.

Comment on lines +124 to +127
end_marker = b"__END_OF_REQUEST__"
end_index = tensor_bytes.find(end_marker)
if end_index == -1:
    raise ValueError("End marker not found in tensor data")
Member


Nice!

Member

@SumanthRH SumanthRH left a comment


Thanks!

Left a nit - trying to understand why sglang import happens in a separate ray task - but beyond that LGTM.

Some CPU tests are failing. Let's make sure CPU tests and the new GPU tests pass.

@CharlieFRuan CharlieFRuan force-pushed the pr-0707-sglang-non-remote branch from 933ecaf to 26ac831 Compare August 12, 2025 07:54
@CharlieFRuan
Collaborator Author

This is ready again. I'm able to run the entire skyrl-train/tests/gpu/test_policy_local_engines_e2e.py. The CPU test fixed itself for some reason (not sure if it was flaky or fixed unintentionally); it was a missing sentencepiece/tiktoken dependency issue.

@CharlieFRuan
Collaborator Author

Current 1.5B Qwen2.5 gsm8k run on 4xH100:
[screenshot: training run results]

Member

@tyler-griggs tyler-griggs left a comment


I'm leaving one open thread to @SumanthRH, otherwise approve!

@SumanthRH SumanthRH merged commit 34e06da into NovaSky-AI:main Aug 12, 2025
3 checks passed
@CharlieFRuan CharlieFRuan deleted the pr-0707-sglang-non-remote branch August 12, 2025 22:20