Merged

Commits (41)
10b7227  update http_server_engine and part of the test (yitianlian, Mar 27, 2025)
e79d178  Add other 3 APIs (jhinpan, Mar 28, 2025)
cd89cc3  update http_server_engine and test (yitianlian, Mar 28, 2025)
7e71b16  Merge branch 'main' into feature/http_server_engine (yitianlian, Mar 28, 2025)
99be7b2  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Mar 28, 2025)
7c3202d  Merge branch 'main' into feature/http_server_engine (jhinpan, Mar 28, 2025)
87302d3  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 2, 2025)
20be928  revise most of problems in comments (yitianlian, Apr 4, 2025)
7b5dfae  revise most of problems in comments (yitianlian, Apr 4, 2025)
181030c  reduce the serialize number (yitianlian, Apr 4, 2025)
c7b6cf7  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 4, 2025)
b0d9c51  Add Comment for update_weights_from_tensor (jhinpan, Apr 5, 2025)
ff8be06  Merge branch 'main' into feature/http_server_engine (jhinpan, Apr 5, 2025)
c0b7b7d  add base_engine(ABC) (yitianlian, Apr 5, 2025)
80df9c1  add docstring for update_weights_from_tensor (yitianlian, Apr 5, 2025)
10c1f43  revise some code structure (yitianlian, Apr 7, 2025)
9375675  revise some code structure (yitianlian, Apr 7, 2025)
d13e44d  Update comments and CI testing logic (jhinpan, Apr 7, 2025)
6c01b8b  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 7, 2025)
e10dea9  Fix lint check (jhinpan, Apr 7, 2025)
6877dbc  revise CI testing logic (yitianlian, Apr 8, 2025)
58cf180  Update base_engine.py (zhaochenyang20, Apr 8, 2025)
66d236b  Update http_server_engine.py (zhaochenyang20, Apr 8, 2025)
e97cd3f  revise some expression in http_server_engine (yitianlian, Apr 8, 2025)
db3937e  Refactoring Code Structure (jhinpan, Apr 9, 2025)
0963bcd  Merge branch 'main' into feature/http_server_engine (jhinpan, Apr 9, 2025)
c266d4a  For Sync (jhinpan, Apr 9, 2025)
dca2e96  Revert MP in Engine (jhinpan, Apr 9, 2025)
dd4ac15  Merge branch 'main' into feature/http_server_engine (jhinpan, Apr 9, 2025)
d38ea8d  update method of updating weights (yitianlian, Apr 9, 2025)
e148a50  Merge branch 'main' into feature/http_server_engine (jhinpan, Apr 9, 2025)
5f77d4b  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 9, 2025)
99dcc14  update name (yitianlian, Apr 10, 2025)
ae05db5  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 10, 2025)
e78bdfe  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 11, 2025)
ae2130b  Quick fix for review (jhinpan, Apr 11, 2025)
128def0  One other HTTP clarification (jhinpan, Apr 11, 2025)
59992df  update doc (yitianlian, Apr 11, 2025)
78542c9  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 11, 2025)
8f95856  Merge branch 'main' into feature/http_server_engine (jhinpan, Apr 11, 2025)
eea5eec  Merge branch 'main' into feature/http_server_engine (zhaochenyang20, Apr 12, 2025)
53 changes: 53 additions & 0 deletions python/sglang/srt/entrypoints/EngineBase.py
@@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, Union

import torch


class EngineBase(ABC):
"""
Abstract base class for engine interfaces that support generation, weight updating, and memory control.
This base class provides a unified API for both HTTP-based server engines and in-process engines.
"""

@abstractmethod
def generate(
self,
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
) -> Union[Dict, Iterator[Dict]]:
"""Generate outputs based on given inputs."""
pass

@abstractmethod
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
flush_cache: bool = True,
):
"""Update model weights with in-memory tensor data."""
pass

@abstractmethod
def release_memory_occupation(self):
"""Release GPU memory occupation temporarily."""
pass

@abstractmethod
def resume_memory_occupation(self):
"""Resume GPU memory occupation which is previously released."""
pass

@abstractmethod
def shutdown(self):
"""Shutdown the engine and clean up resources."""
pass
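
For orientation, here is a minimal sketch (not part of this PR) of what a concrete implementation of this contract looks like; EchoEngine and its trivial method bodies are hypothetical stand-ins:

# Hypothetical illustration only: EchoEngine is not part of the diff.
from sglang.srt.entrypoints.EngineBase import EngineBase


class EchoEngine(EngineBase):
    """Toy engine that satisfies the abstract surface by echoing its input."""

    def generate(self, prompt=None, sampling_params=None, input_ids=None,
                 image_data=None, return_logprob=False, logprob_start_len=None,
                 top_logprobs_num=None, token_ids_logprob=None, lora_path=None,
                 custom_logit_processor=None):
        # A real engine would run inference here.
        return {"text": prompt}

    def update_weights_from_tensor(self, named_tensors, load_format=None, flush_cache=True):
        # A real engine would copy these tensors into the live model.
        return {"success": True, "message": f"received {len(named_tensors)} tensors"}

    def release_memory_occupation(self):
        pass

    def resume_memory_occupation(self):
        pass

    def shutdown(self):
        pass


engine = EchoEngine()
print(engine.generate(prompt="hello"))  # {'text': 'hello'}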
3 changes: 2 additions & 1 deletion python/sglang/srt/entrypoints/engine.py
@@ -38,6 +38,7 @@
import uvloop

from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
from sglang.srt.entrypoints.EngineBase import EngineBase
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
@@ -78,7 +79,7 @@
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class Engine:
class Engine(EngineBase):
"""
The entry point to the inference engine.

27 changes: 26 additions & 1 deletion python/sglang/srt/entrypoints/http_server.py
@@ -25,8 +25,11 @@
import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Callable, Dict, Optional
from typing import AsyncIterator, Callable, Dict, Optional, Union

from sglang.srt.model_executor.model_runner import LocalSerializedTensor

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -60,6 +63,7 @@
SetInternalStateReq,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +84,7 @@
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
MultiprocessingSerializer,
add_api_key_middleware,
add_prometheus_middleware,
delete_directory,
@@ -411,6 +416,26 @@ async def init_weights_update_group(
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/update_weights_from_tensor")
async def update_weights_from_tensor(
obj: UpdateWeightsFromTensorReqInput, request: Request
):
"""Update the weights from tensor inplace without re-launching the server.
Notes:
1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors.
2. HTTPS will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model.
3. Any binary data in the named tensors should be base64 encoded.
"""

success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
obj, request
)
content = {"success": success, "message": message}
return ORJSONResponse(
content, status_code=200 if success else HTTPStatus.BAD_REQUEST
)


@app.post("/update_weights_from_distributed")
async def update_weights_from_distributed(
obj: UpdateWeightsFromDistributedReqInput, request: Request
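
The new endpoint can be exercised with a plain POST; here is a hedged client-side sketch (the host, port, tensor name, and tp_size=1 are assumptions, and the tensor must sit on a GPU visible to the server, since only IPC metadata crosses the HTTP boundary):

import requests
import torch

from sglang.srt.utils import MultiprocessingSerializer

named_tensors = [("lm_head.weight", torch.zeros(8, 8, device="cuda"))]
payload = {
    # One base64-encoded serialized copy per TP worker (one here, since tp_size=1).
    "serialized_named_tensors": [
        MultiprocessingSerializer.serialize(named_tensors, output_str=True)
    ],
    "load_format": None,
    "flush_cache": True,
}
resp = requests.post("http://127.0.0.1:30000/update_weights_from_tensor", json=payload)
print(resp.json())  # {"success": ..., "message": ...}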
140 changes: 140 additions & 0 deletions python/sglang/srt/entrypoints/http_server_engine.py
@@ -0,0 +1,140 @@
import base64
import copy
import dataclasses
import multiprocessing
import pickle
import threading
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
import torch
import torch.distributed as dist

from sglang.srt.entrypoints.EngineBase import EngineBase
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree


def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:

p = multiprocessing.Process(target=launch_server, args=(server_args,))
p.start()

base_url = server_args.url()
timeout = 300.0 # Increased timeout to 5 minutes for downloading large models
start_time = time.time()

with requests.Session() as session:
while time.time() - start_time < timeout:
try:
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {server_args.api_key}",
}
response = session.get(f"{base_url}/health_generate", headers=headers)
if response.status_code == 200:
return p
except requests.RequestException:
pass

if not p.is_alive():
raise Exception("Server process terminated unexpectedly.")

time.sleep(2)
Collaborator: this sleep is required? Generally we should not wait for unnecessary sleep time since the server will be launched several times in the training process.

Contributor (author): I think this is necessary, because it means we check whether the server is launched every two seconds.


p.terminate()
raise TimeoutError("Server failed to start within the timeout period.")


class HttpServerEngineAdapter(EngineBase):
"""
You can use this class to launch a server from a VerlEngine instance.
We recommend using this class only when you need an HTTP server;
otherwise, use Engine directly.
"""

def __init__(self, **kwargs):
self.server_args = ServerArgs(**kwargs)
print(f"launch_server_from_verl_engine {self.server_args.port}")
Collaborator: nit: maybe remove this line since the names are wrong etc

Collaborator: Which line to remove? This one is a bit confusing to me.

Contributor (author): I think we need to remove the print line

Collaborator: yes (or at least rename the word "launch_server_from_verl_engine", which seems to be the old name of a function that we called)

self.process = launch_server_process(self.server_args)

def _make_request(self, endpoint: str, payload: dict = None):
Collaborator (suggested change):
-    def _make_request(self, endpoint: str, payload: dict = None):
+    def _make_request(self, endpoint: str, payload: Optional[dict] = None):

"""Make a POST request to the specified endpoint with the given payload.
Args:
endpoint: The API endpoint to call
payload: The JSON payload to send (default: empty dict)
Returns:
The JSON response from the server
"""
url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}"
response = requests.post(url, json=payload or {})
response.raise_for_status()
return response.json()

def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
flush_cache: bool = False,
):
"""
Update model weights from tensor data. The HTTP request carries only the serialized tensor metadata; the real weights are copied directly from GPU memory.
Collaborator: nit: it seems most people will use HTTP instead of HTTPS (indeed wondering whether SGLang supports https today), thus would be great to change doc
(same for other "HTTPS" words)


Note: The model should be on GPUs rather than CPU for this functionality to work properly.
If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
Comment on lines +89 to +90:

Collaborator: state that meta data will be transferred by https, but real weights should be copied directly from GPU.

Contributor (author): fixed in the newest commit


"""

return self._make_request(
"update_weights_from_tensor",
{
"serialized_named_tensors": [
MultiprocessingSerializer.serialize(named_tensors, output_str=True)
for _ in range(self.server_args.tp_size)
Contributor: I'm wondering if here we can call the serialization only once.

Contributor (author): I remember MultiprocessingSerializer.serialize(named_tensors) can't be posted by HTTP?

Contributor: I meant something like

x = HttpSerializer.serialize(MultiprocessingSerializer.serialize(named_tensors))
response = requests.post(self._url("update_weights_from_tensor"), json={"serialized_named_tensors": [x for _ in range(self.server_args.tp_size)] ...

Contributor (author): Oh, sure! I will fix it now.

Contributor: Or we can just send a single copy of HttpSerializer.serialize(MultiprocessingSerializer.serialize(named_tensors)) to reduce the HTTP payload size? The server can make multiple copies after receiving it.

Contributor (author): The size of serialized tensors is really small, so I think it will not take a long time. Also, it seems that the update_weight_from_tensor entry point doesn't have access to the tp size.


],
"load_format": load_format,
"flush_cache": flush_cache,
},
)

def shutdown(self):
kill_process_tree(self.process.pid)

def generate(
self,
prompt=None,
sampling_params=None,
input_ids=None,
image_data=None,
return_logprob=False,
logprob_start_len=None,
top_logprobs_num=None,
token_ids_logprob=None,
lora_path=None,
custom_logit_processor=None,
):
payload = {
"text": prompt,
"sampling_params": sampling_params,
"input_ids": input_ids,
"image_data": image_data,
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
"token_ids_logprob": token_ids_logprob,
"lora_path": lora_path,
"custom_logit_processor": custom_logit_processor,
}
# Filter out None values
payload = {k: v for k, v in payload.items() if v is not None}

return self._make_request("generate", payload)

def release_memory_occupation(self):
return self._make_request("release_memory_occupation")

def resume_memory_occupation(self):
return self._make_request("resume_memory_occupation")
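
Typical usage of the adapter looks like the following sketch (the model path is a placeholder; any ServerArgs keyword passes through, and construction blocks until the spawned server answers /health_generate or the 300-second timeout fires):

from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter

engine = HttpServerEngineAdapter(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    port=30000,
)
try:
    out = engine.generate(
        prompt="The capital of France is",
        sampling_params={"temperature": 0, "max_new_tokens": 8},
    )
    print(out["text"])
finally:
    engine.shutdown()  # kills the spawned server's process tree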
29 changes: 22 additions & 7 deletions python/sglang/srt/entrypoints/verl_engine.py
@@ -12,16 +12,18 @@
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.distributed as dist
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor

from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server import Engine
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj


@@ -30,6 +32,7 @@ def __init__(
self,
device_mesh_cpu: DeviceMesh,
nnodes: int = 1,
backend: Literal["engine", "server"] = "engine",
**kwargs,
):
monkey_patch_torch_reductions()
@@ -40,13 +43,25 @@
node_rank = self._tp_rank // tp_size_per_node
first_rank_in_node = self._tp_rank % tp_size_per_node == 0

if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
)
# Common engine keyword arguments
engine_kwargs = dict(
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
)

if backend == "engine":
if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(**engine_kwargs)
else:
self._engine = None

elif backend == "server":
if self._tp_rank == 0:
Collaborator: wondering whether it will work for multi node

Contributor (author): We will consider multi-node in a new pr.

Collaborator (@fzyzcjy, Apr 8, 2025): curious whether change to

        engine_class = Engine or HttpServerEngineAdapter depend on the `backend`
        if first_rank_in_node:
            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self._engine = engine_class(
                **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
            )
        else:
            self._engine = None

will work directly

Collaborator: haven't test it. We can do it in next PR 😂

Collaborator (@fzyzcjy, Apr 9, 2025): I see, though that change looks easy and can be quickly tested, and makes the code a bit more unified

self._engine = HttpServerEngineAdapter(**engine_kwargs)
else:
self._engine = None
else:
self._engine = None
raise ValueError(f"Unsupported backend: {backend}")

dist.barrier(group=self._device_mesh_cpu.get_group())

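
A hedged sketch of selecting the new backend (assumes a single-node torchrun launch with two ranks; only TP rank 0 hosts the HTTP server, while the other ranks keep self._engine = None and synchronize via the barrier):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine

dist.init_process_group(backend="gloo")  # assumption: launched via torchrun
device_mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("tp",))

engine = VerlEngine(
    device_mesh_cpu=device_mesh["tp"],
    backend="server",  # "engine" (default, in-process) or "server" (HTTP adapter)
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
)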
16 changes: 12 additions & 4 deletions python/sglang/srt/managers/io_struct.py
@@ -700,10 +700,18 @@ class UpdateWeightsFromDistributedReqOutput:

@dataclass
class UpdateWeightsFromTensorReqInput:
# List containing one serialized Dict[str, torch.Tensor] per TP worker
serialized_named_tensors: List[bytes]
load_format: Optional[str]
flush_cache: bool
"""Update model weights from tensor input.

- Binary data like tensors are base64 encoded
Collaborator: it is not base64 encoded when this object is created from Engine...

- Data is structured in JSON for easy transmission over HTTP
- No pickle serialization is used for security reasons
Collaborator: well there is pickle indeed... (to serialize torch.Tensors)

"""

serialized_named_tensors: List[str]
Collaborator (suggested change):
-    serialized_named_tensors: List[str]
+    serialized_named_tensors: List[Union[str, bytes]]

# Optional format specification for loading
load_format: Optional[str] = None
# Whether to flush the cache after updating weights
flush_cache: bool = True


@dataclass
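
A sketch of the two population paths the review discusses (names are illustrative): the HTTP adapter sends base64 strings, while the in-process Engine path passes raw bytes, which is why the reviewers suggest List[Union[str, bytes]]:

import torch

from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
from sglang.srt.utils import MultiprocessingSerializer

named_tensors = [("embed_tokens.weight", torch.zeros(4, 4))]

# HTTP path: JSON-safe base64 string, one copy per TP worker.
http_req = UpdateWeightsFromTensorReqInput(
    serialized_named_tensors=[
        MultiprocessingSerializer.serialize(named_tensors, output_str=True)
    ],
)

# In-process path: raw bytes (dataclasses do not enforce the List[str] annotation).
local_req = UpdateWeightsFromTensorReqInput(
    serialized_named_tensors=[MultiprocessingSerializer.serialize(named_tensors)],
)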
33 changes: 31 additions & 2 deletions python/sglang/srt/utils.py
@@ -1480,14 +1480,43 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:

class MultiprocessingSerializer:
@staticmethod
def serialize(obj):
def serialize(obj, output_str: bool = False):
"""
Serialize a Python object using ForkingPickler.

Args:
obj: The object to serialize.
output_str (bool): If True, return a base64-encoded string instead of raw bytes.

Returns:
bytes or str: The serialized object.
"""
buf = io.BytesIO()
ForkingPickler(buf).dump(obj)
buf.seek(0)
return buf.read()
output = buf.read()

if output_str:
# Convert bytes to base64-encoded string
output = base64.b64encode(output).decode("utf-8")

return output

@staticmethod
def deserialize(data):
"""
Deserialize a previously serialized object.

Args:
data (bytes or str): The serialized data, optionally base64-encoded.

Returns:
The deserialized Python object.
"""
if isinstance(data, str):
# Decode base64 string to bytes
data = base64.b64decode(data)

return ForkingPickler.loads(data)


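A round-trip sketch of the new output_str flag (a CPU tensor keeps the example self-contained; deserialize() transparently base64-decodes any str input before unpickling):

import torch

from sglang.srt.utils import MultiprocessingSerializer

obj = [("w", torch.ones(2, 2))]
as_str = MultiprocessingSerializer.serialize(obj, output_str=True)
assert isinstance(as_str, str)  # JSON-safe, suitable for an HTTP payload

restored = MultiprocessingSerializer.deserialize(as_str)
assert torch.equal(restored[0][1], torch.ones(2, 2))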