[Generator] Support non-remote (e.g. colocated) SGLang engine #68
Changes from 22 commits
```diff
@@ -25,7 +25,7 @@ def tp_size(self):
         return ray.get(self.inference_engine_actor.tp_size.remote())
 
     async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
-        return await self.inference_engine_actor.generate.remote(input_batch)
+        return await self.inference_engine_actor.generate.remote(input_batch=input_batch)
 
     async def wake_up(self, *args: Any, **kwargs: Any):
         return await self.inference_engine_actor.wake_up.remote(*args, **kwargs)
```
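Forwarding `input_batch` by keyword makes the pass-through call explicit about which parameter it binds on the remote actor. A minimal usage sketch of the wrapper method (the `engine` and `input_batch` objects are assumed to already exist; this is not code from the PR):

```python
import asyncio

async def generate_once(engine, input_batch):
    # engine: a RayWrappedInferenceEngine holding a Ray actor handle.
    # input_batch: an InferenceEngineInput for that engine.
    # Awaiting the call resolves the ObjectRef returned by the remote generate().
    return await engine.generate(input_batch=input_batch)

# From a synchronous entry point:
# output = asyncio.run(generate_once(engine, input_batch))
```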
```diff
@@ -62,21 +62,29 @@ def create_ray_wrapped_inference_engines(
     max_model_len: int,
     shared_pg=None,
     gpu_memory_utilization=None,
-    vllm_enable_sleep=False,
+    inference_engine_enable_sleep=False,
     async_engine=False,
     max_num_batched_tokens=8192,
     max_num_seqs=1024,
     sampling_params: Optional[Dict[str, Any]] = None,
     tokenizer=None,
+    backend="vllm",
 ) -> List[InferenceEngineInterface]:
     """
     Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface instances.
     """
-    import vllm
-    from skyrl_train.inference_engines.vllm.vllm_engine import VLLMRayActor, AsyncVLLMRayActor
     from skyrl_train.utils import ray_noset_visible_devices, get_all_env_variables, get_ray_pg_ready_with_timeout
 
-    assert vllm.__version__ >= "0.8.3", "SkyTrainer only supports vLLM >= 0.8.3"
+    if backend == "vllm":
+        import vllm
+        from skyrl_train.inference_engines.vllm.vllm_engine import VLLMRayActor, AsyncVLLMRayActor
+
+        assert vllm.__version__ >= "0.8.3", "SkyTrainer only supports vLLM >= 0.8.3"
+    elif backend == "sglang":
+        # We import SGLang later to avoid importing vllm. See `get_sglang_engine` for more.
+        pass
+    else:
+        raise ValueError(f"Unsupported backend: {backend}")
     inference_engine_actors = []
     noset_visible_devices = ray_noset_visible_devices(ray.get(get_all_env_variables.remote()))
     # NOTE: we use the ray backend for tensor parallel size > 1 to explicitly manage resource allocation
```
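With the new `backend` argument, the same creation path can now stand up SGLang engines instead of vLLM ones, and the import of either stack is deferred until the backend is actually selected. A hedged sketch of a call site (only keyword arguments visible in this hunk are spelled out; the model, engine count, and parallelism arguments sit outside the hunk and are elided):

```python
# Illustrative call; concrete values are assumptions for the example,
# not settings taken from the PR.
engines = create_ray_wrapped_inference_engines(
    # ... model path, engine count, and parallelism arguments elided ...
    max_model_len=4096,
    gpu_memory_utilization=0.8,
    inference_engine_enable_sleep=True,  # renamed from vllm_enable_sleep
    max_num_batched_tokens=8192,
    max_num_seqs=1024,
    backend="sglang",                    # new: "vllm" (default) or "sglang"
)
```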
```diff
@@ -106,42 +114,92 @@ def create_ray_wrapped_inference_engines(
             placement_group_bundle_index=i * tensor_parallel_size,
         )
 
-        if async_engine:
-            actor_class = AsyncVLLMRayActor
-        else:
-            actor_class = VLLMRayActor
-
-        vllm_engine = actor_class.options(
-            num_cpus=num_gpus,
-            num_gpus=num_gpus,
-            scheduling_strategy=scheduling_strategy,
-        ).remote(
-            model=pretrain,
-            enforce_eager=enforce_eager,
-            worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap",
-            tensor_parallel_size=tensor_parallel_size,
-            seed=seed + i,
-            distributed_executor_backend=distributed_executor_backend,
-            max_model_len=max_model_len,
-            enable_prefix_caching=enable_prefix_caching,
-            dtype=model_dtype,
-            trust_remote_code=True,
-            vllm_v1_disable_multiproc=vllm_v1_disable_multiproc,
-            gpu_memory_utilization=gpu_memory_utilization,
-            bundle_indices=bundle_indices,
-            num_gpus=0.2 if use_hybrid_engine else 1,
-            enable_sleep_mode=vllm_enable_sleep,
-            noset_visible_devices=noset_visible_devices,
-            max_num_batched_tokens=max_num_batched_tokens,
-            max_num_seqs=max_num_seqs,
-            sampling_params=sampling_params,
-            tokenizer=tokenizer,
-        )
-        inference_engine_actors.append(vllm_engine)
+        if backend == "vllm":
+            if async_engine:
+                actor_class = AsyncVLLMRayActor
+            else:
+                actor_class = VLLMRayActor
+
+            engine = actor_class.options(
+                num_cpus=num_gpus,
+                num_gpus=num_gpus,
+                scheduling_strategy=scheduling_strategy,
+            ).remote(
+                model=pretrain,
+                enforce_eager=enforce_eager,
+                worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap",
+                tensor_parallel_size=tensor_parallel_size,
+                seed=seed + i,
+                distributed_executor_backend=distributed_executor_backend,
+                max_model_len=max_model_len,
+                enable_prefix_caching=enable_prefix_caching,
+                dtype=model_dtype,
+                trust_remote_code=True,
+                vllm_v1_disable_multiproc=vllm_v1_disable_multiproc,
+                gpu_memory_utilization=gpu_memory_utilization,
+                bundle_indices=bundle_indices,
+                num_gpus=0.2 if use_hybrid_engine else 1,
+                enable_sleep_mode=inference_engine_enable_sleep,
+                noset_visible_devices=noset_visible_devices,
+                max_num_batched_tokens=max_num_batched_tokens,
+                max_num_seqs=max_num_seqs,
+                sampling_params=sampling_params,
+                tokenizer=tokenizer,
+            )
+        elif backend == "sglang":
+            # NOTE: there is no async / sync engine distinction in SGLang
+
+            # NOTE(Charlie): We need `torch.cuda.is_available()` to be True to import SGLang. Otherwise, it requires
+            # importing vllm. See https://github.com/sgl-project/sglang/blob/v0.4.8.post1/python/sglang/srt/layers/quantization/utils.py#L11-L17
+            # Similar comment: https://github.com/volcengine/verl/blob/9cc307767b0c787e8f5ef581dac929f7bde044ef/verl/workers/fsdp_workers.py#L520-L527
+            @ray.remote
+            def get_sglang_engine():
+                # A workaround to avoid importing vllm is to give this task a GPU.
+                import os
+
+                before_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+                os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+                from skyrl_train.inference_engines.sglang.sglang_engine import SGLangRayActor
+
+                os.environ["CUDA_VISIBLE_DEVICES"] = before_cuda_visible_devices
+
+                actor_class = SGLangRayActor
+                engine = actor_class.options(
+                    num_cpus=num_gpus,
+                    num_gpus=num_gpus,
+                    scheduling_strategy=scheduling_strategy,
+                ).remote(
+                    model_path=pretrain,
+                    tp_size=tensor_parallel_size,
+                    mem_fraction_static=gpu_memory_utilization,
+                    random_seed=seed + i,
+                    context_length=max_model_len,
+                    disable_radix_cache=not enable_prefix_caching,
+                    dtype=model_dtype,
+                    trust_remote_code=True,
+                    max_prefill_tokens=max_num_batched_tokens,
+                    max_running_requests=max_num_seqs,
+                    # Borrowed from veRL's SGLang rollout
+                    mm_attention_backend="fa3",
+                    attention_backend="fa3",
+                    enable_memory_saver=inference_engine_enable_sleep,
+                    # Will be popped before instantiating sgl.Engine
+                    distributed_executor_backend=distributed_executor_backend,
+                    noset_visible_devices=noset_visible_devices,
+                    bundle_indices=bundle_indices,
+                    num_gpus=0.2 if use_hybrid_engine else 1,
+                    sampling_params=sampling_params,
+                    tokenizer=tokenizer,
+                )
+                return engine
+
+            engine = ray.get(get_sglang_engine.remote())
+
+        inference_engine_actors.append(engine)
 
     engines = [RayWrappedInferenceEngine(actor_handle) for actor_handle in inference_engine_actors]
 
-    if vllm_enable_sleep:
+    if inference_engine_enable_sleep:
        sleep_refs = [engine.inference_engine_actor.sleep.remote() for engine in engines]
        ray.get(sleep_refs)
```

SumanthRH marked a review conversation on the GPU-workaround comment as resolved.

Review thread on lines +162 to +167:

Member: Seems like another issue with …

Collaborator: We are going to fix it, see ray-project/ray#54868

Member: Great!
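The core trick in `get_sglang_engine` is to make `torch.cuda.is_available()` return True for the duration of the SGLang import, then restore the environment, so the import never falls back to the vllm-dependent path. A standalone sketch of that save/set/restore pattern (the module name is a hypothetical placeholder; the PR applies it to the `SGLangRayActor` import inside a GPU-assigned Ray task):

```python
import os
from contextlib import contextmanager

@contextmanager
def gpu_visible_for_import(device: str = "0"):
    # Save the current CUDA_VISIBLE_DEVICES (possibly empty), expose one GPU so
    # torch.cuda.is_available() is True while the import runs, then restore it.
    previous = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    os.environ["CUDA_VISIBLE_DEVICES"] = device
    try:
        yield
    finally:
        os.environ["CUDA_VISIBLE_DEVICES"] = previous

# Usage sketch:
# with gpu_visible_for_import():
#     import some_gpu_gated_module  # hypothetical stand-in for the SGLang import
```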
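Both branches request `num_gpus=0.2 if use_hybrid_engine else 1`, which is what lets an inference engine actor share a GPU bundle with training workers when colocated. A toy sketch of fractional-GPU colocation with a Ray placement group (the actor class and bundle shape here are assumptions, not the PR's classes):

```python
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

# A single bundle whose GPU will be shared by two actors.
pg = placement_group([{"GPU": 1, "CPU": 2}])
ray.get(pg.ready())

@ray.remote
class ToyWorker:
    def visible_devices(self):
        import os
        return os.environ.get("CUDA_VISIBLE_DEVICES", "")

strategy = PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
# Colocated: the "engine" takes 0.2 of the GPU, the "trainer" the remaining 0.8.
engine = ToyWorker.options(num_gpus=0.2, scheduling_strategy=strategy).remote()
trainer = ToyWorker.options(num_gpus=0.8, scheduling_strategy=strategy).remote()
print(ray.get([engine.visible_devices.remote(), trainer.visible_devices.remote()]))
```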
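When `inference_engine_enable_sleep` is set, the engines are put to sleep right after creation so a colocated trainer can claim the GPU memory first; they are then woken up around generation. A hedged sketch of that cycle using the wrapper methods shown in the first hunk (the sleep call goes through the underlying actor handle, mirroring the diff above; pairing of engines and batches is illustrative):

```python
import asyncio
import ray

async def generate_with_sleep_cycle(engines, input_batches):
    # Wake the colocated engines before generation.
    await asyncio.gather(*(engine.wake_up() for engine in engines))

    # One batch per engine; the pairing is illustrative.
    outputs = await asyncio.gather(
        *(engine.generate(input_batch=batch) for engine, batch in zip(engines, input_batches))
    )

    # Put the engines back to sleep so the trainer can reuse the GPU memory,
    # mirroring the sleep() call on the raw actor handle in the diff above.
    ray.get([engine.inference_engine_actor.sleep.remote() for engine in engines])
    return outputs
```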