Support server-based rollout in VerlEngine #4848
New file: python/sglang/srt/entrypoints/EngineBase.py (+53 lines)

```python
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, Union

import torch


class EngineBase(ABC):
    """
    Abstract base class for engine interfaces that support generation,
    weight updating, and memory control.

    This base class provides a unified API for both HTTP-based engines
    and in-process engines.
    """

    @abstractmethod
    def generate(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
    ) -> Union[Dict, Iterator[Dict]]:
        """Generate outputs based on the given inputs."""
        pass

    @abstractmethod
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]],
        load_format: Optional[str] = None,
        flush_cache: bool = True,
    ):
        """Update model weights with in-memory tensor data."""
        pass

    @abstractmethod
    def release_memory_occupation(self):
        """Temporarily release GPU memory occupation."""
        pass

    @abstractmethod
    def resume_memory_occupation(self):
        """Resume GPU memory occupation that was previously released."""
        pass

    @abstractmethod
    def shutdown(self):
        """Shut down the engine and clean up resources."""
        pass
```
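Since every concrete engine must implement this contract, a minimal sketch of a conforming subclass may help when reading the interface. `DummyEngine` and its behavior are hypothetical; only the `EngineBase` API comes from this PR.

```python
from typing import Dict, Iterator, Union

from sglang.srt.entrypoints.EngineBase import EngineBase


class DummyEngine(EngineBase):
    """Hypothetical in-process engine satisfying the EngineBase contract."""

    def generate(self, prompt=None, sampling_params=None, input_ids=None,
                 image_data=None, return_logprob=False, logprob_start_len=None,
                 top_logprobs_num=None, token_ids_logprob=None, lora_path=None,
                 custom_logit_processor=None) -> Union[Dict, Iterator[Dict]]:
        # A real engine would run inference here; this stub echoes the prompt.
        return {"text": prompt}

    def update_weights_from_tensor(self, named_tensors, load_format=None,
                                   flush_cache=True):
        # A real engine would copy each tensor into the live model.
        for name, tensor in named_tensors:
            print(f"would load {name} with shape {tuple(tensor.shape)}")

    def release_memory_occupation(self):
        pass  # e.g. offload weights and KV cache to free GPU memory

    def resume_memory_occupation(self):
        pass  # e.g. restore what release_memory_occupation offloaded

    def shutdown(self):
        pass  # nothing to clean up in this stub
```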
New file: python/sglang/srt/entrypoints/http_server_engine.py (excerpt of the 140 added lines)

```python
import base64
import copy
import dataclasses
import multiprocessing
import pickle
import threading
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
import torch
import torch.distributed as dist

from sglang.srt.entrypoints.EngineBase import EngineBase
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree


def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
    p = multiprocessing.Process(target=launch_server, args=(server_args,))
    p.start()

    base_url = server_args.url()
    timeout = 300.0  # Increased timeout to 5 minutes for downloading large models
    start_time = time.time()

    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                headers = {
                    "Content-Type": "application/json; charset=utf-8",
                    "Authorization": f"Bearer {server_args.api_key}",
                }
                response = session.get(f"{base_url}/health_generate", headers=headers)
                if response.status_code == 200:
                    return p
            except requests.RequestException:
                pass

            if not p.is_alive():
                raise Exception("Server process terminated unexpectedly.")

            time.sleep(2)

    p.terminate()
    raise TimeoutError("Server failed to start within the timeout period.")


class HttpServerEngineAdapter(EngineBase):
    """
    You can use this class to launch a server from a VerlEngine instance.
    We recommend using this class only when you need the HTTP server;
    otherwise, you can use Engine directly.
    """

    def __init__(self, **kwargs):
        self.server_args = ServerArgs(**kwargs)
        print(f"launch_server_from_verl_engine {self.server_args.port}")
        self.process = launch_server_process(self.server_args)

    def _make_request(self, endpoint: str, payload: dict = None):
        ...  # (remainder of the file is not shown in this view)
```
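A minimal usage sketch of the adapter. The keyword arguments are forwarded verbatim to ServerArgs, so the model path and port below are placeholders; `shutdown()` is assumed to terminate the spawned server process (consistent with the `kill_process_tree` import above), since `EngineBase` requires it.

```python
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter

# kwargs map one-to-one onto ServerArgs fields, so any server flag works here.
adapter = HttpServerEngineAdapter(
    model_path="placeholder/model",  # placeholder model path
    port=30000,                      # placeholder port
)

# ... generate / update_weights_from_tensor calls now go over HTTP ...

adapter.shutdown()  # assumed to tear down self.process (see kill_process_tree)
```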
Suggested change:

```diff
-    def _make_request(self, endpoint: str, payload: dict = None):
+    def _make_request(self, endpoint: str, payload: Optional[dict] = None):
```

Review discussion:

Reviewer: nit: It seems most people will use HTTP instead of HTTPS (indeed, I wonder whether SGLang supports HTTPS today), so it would be great to change the doc accordingly (same for the other "HTTPS" mentions).

Reviewer: State that metadata will be transferred by HTTPS, but the real weights should be copied directly from GPU.

Author: Fixed in the newest commit.

Reviewer: I'm wondering if here we can call the serialization only once.

Author: I remember that MultiprocessingSerializer.serialize(named_tensors) can't be posted over HTTP?

Reviewer: I meant something like:

x = HttpSerializer.serialize(MultiprocessingSerializer.serialize(named_tensors))
response = requests.post(self._url("update_weights_from_tensor"), json={"serialized_named_tensors": [x for _ in range(self.server_args.tp_size)] ...

Author: Oh, sure! I will fix it now.

Reviewer: Or we could just send a single copy of HttpSerializer.serialize(MultiprocessingSerializer.serialize(named_tensors)) to reduce the HTTP payload size? The server can make multiple copies after receiving it.

Author: The serialized tensors are really small, so I don't think this will take long. Also, it seems that the update_weights_from_tensor entry point doesn't have access to the TP size.
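For concreteness, here is a minimal sketch of the serialize-once pattern the thread converges on, assuming base64 is the HTTP-safe encoding (consistent with the io_struct docstring further below). `post_weights`, `base_url`, and `tp_size` are illustrative names, not part of the PR; only `MultiprocessingSerializer.serialize` and the `/update_weights_from_tensor` endpoint come from it.

```python
import base64

import requests

from sglang.srt.utils import MultiprocessingSerializer


def post_weights(base_url, named_tensors, tp_size, load_format=None):
    # Serialize once, as suggested above, instead of once per TP worker.
    raw = MultiprocessingSerializer.serialize(named_tensors)  # -> bytes
    encoded = base64.b64encode(raw).decode("utf-8")  # JSON-safe string
    payload = {
        # One copy per TP worker, mirroring the reviewer's snippet.
        "serialized_named_tensors": [encoded for _ in range(tp_size)],
        "load_format": load_format,
        "flush_cache": True,
    }
    return requests.post(f"{base_url}/update_weights_from_tensor", json=payload)
```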
Modified file: python/sglang/srt/entrypoints/verl_engine.py

```diff
@@ -12,16 +12,18 @@
 # limitations under the License.
 # ==============================================================================
 import os
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor

+from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj

@@ -30,6 +32,7 @@ def __init__(
         self,
         device_mesh_cpu: DeviceMesh,
         nnodes: int = 1,
+        backend: Literal["engine", "server"] = "engine",
         **kwargs,
     ):
         monkey_patch_torch_reductions()

@@ -40,13 +43,25 @@ def __init__(
         node_rank = self._tp_rank // tp_size_per_node
         first_rank_in_node = self._tp_rank % tp_size_per_node == 0

-        if first_rank_in_node:
-            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
-            self._engine = Engine(
-                **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
-            )
-        else:
-            self._engine = None
+        # Common engine keyword arguments
+        engine_kwargs = dict(
+            **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
+        )
+
+        if backend == "engine":
+            if first_rank_in_node:
+                os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
+                self._engine = Engine(**engine_kwargs)
+            else:
+                self._engine = None
+        elif backend == "server":
+            if self._tp_rank == 0:
+                self._engine = HttpServerEngineAdapter(**engine_kwargs)
+            else:
+                self._engine = None
+        else:
+            raise ValueError(f"Unsupported backend: {backend}")

         dist.barrier(group=self._device_mesh_cpu.get_group())
```

Review discussion (on the server backend branch):

Collaborator: Wondering whether it will work for multi-node.

Contributor (Author): We will consider multi-node in a new PR.

Collaborator: Curious whether changing to … would work directly.

Collaborator: Haven't tested it. We can do it in the next PR 😂

Collaborator: I see, though that change looks easy and could be quickly tested, and it would make the code a bit more unified.
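For context, a minimal sketch of selecting the new backend follows. The mesh setup and model path are placeholders, and the snippet assumes it runs under torchrun so the process group can initialize; only the `backend` parameter and `device_mesh_cpu` argument come from this PR.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine

# Placeholder single-node, 2-way tensor-parallel setup; run under torchrun.
dist.init_process_group(backend="gloo")
device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(2,), mesh_dim_names=("tp",))

engine = VerlEngine(
    device_mesh_cpu=device_mesh_cpu["tp"],
    backend="server",  # new in this PR; "engine" remains the default
    model_path="placeholder/model",  # forwarded to the engine via **kwargs
)
```

With `backend="server"`, only TP rank 0 holds the `HttpServerEngineAdapter`; the other ranks keep `self._engine = None`, matching the diff above.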
Modified file: python/sglang/srt/managers/io_struct.py

```diff
@@ -700,10 +700,18 @@ class UpdateWeightsFromDistributedReqOutput:

 @dataclass
 class UpdateWeightsFromTensorReqInput:
+    """Update model weights from tensor input.
+
+    - Binary data such as tensors is base64 encoded
+    - Data is structured in JSON for easy transmission over HTTP
+    - No pickle serialization is used, for security reasons
+    """
+
     # List containing one serialized Dict[str, torch.Tensor] per TP worker
-    serialized_named_tensors: List[bytes]
+    serialized_named_tensors: List[str]
     load_format: Optional[str]
     flush_cache: bool
```

Suggested change:

```diff
-    serialized_named_tensors: List[str]
+    serialized_named_tensors: List[Union[str, bytes]]
```
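On the receiving side, the docstring implies the inverse transform before the bytes reach the deserializer. A minimal sketch: `decode_tensor_payload` is a hypothetical name, and `MultiprocessingSerializer.deserialize` is assumed to be the counterpart of the `serialize` call used on the client.

```python
import base64
from typing import List

from sglang.srt.utils import MultiprocessingSerializer


def decode_tensor_payload(serialized_named_tensors: List[str]) -> List:
    # Undo the base64 step, then recover the named tensors for each TP worker.
    return [
        MultiprocessingSerializer.deserialize(base64.b64decode(s))
        for s in serialized_named_tensors
    ]
```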
Review discussion (on the `time.sleep(2)` in `launch_server_process`):

Reviewer: Is this sleep required? Generally we should not add unnecessary wait time, since the server will be launched several times during training.

Author: I think this is necessary; it just means we check whether the server has launched every two seconds.