From 10b72272414bf6edebbac589aea63735540e8484 Mon Sep 17 00:00:00 2001 From: yitianlian Date: Thu, 27 Mar 2025 12:24:03 +0000 Subject: [PATCH 01/25] update http_server_engine and part of the test --- python/sglang/srt/entrypoints/http_server.py | 27 ++ .../srt/entrypoints/http_server_engine.py | 121 +++++++ python/sglang/srt/entrypoints/verl_engine.py | 19 +- python/test_verl_engine_server.py | 315 ++++++++++++++++++ 4 files changed, 481 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/entrypoints/http_server_engine.py create mode 100644 python/test_verl_engine_server.py diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1024370a10c..68f08e4d725 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -60,6 +60,7 @@ SetInternalStateReq, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, VertexGenerateReqInput, ) from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -411,6 +412,32 @@ async def init_weights_update_group( return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) +def deserialize_from_http(encoded_data): + import base64 + import pickle + + # Convert from base64 string back to original data + pickled = base64.b64decode(encoded_data) + return pickle.loads(pickled) + + +@app.post("/update_weights_from_tensor") +async def update_weights_from_tensor( + obj: UpdateWeightsFromTensorReqInput, request: Request +): + obj.serialized_named_tensors = [ + deserialize_from_http(item) for item in obj.serialized_named_tensors + ] + success, message = await _global_state.tokenizer_manager.update_weights_from_tensor( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + @app.post("/update_weights_from_distributed") async def update_weights_from_distributed( obj: UpdateWeightsFromDistributedReqInput, request: Request diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py new file mode 100644 index 00000000000..8b7743b8ad4 --- /dev/null +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -0,0 +1,121 @@ +import base64 +import copy +import pickle +import threading +import time +from typing import Dict, List, Optional, Tuple, Union + +import requests +import torch +import torch.distributed as dist +from torch.distributed.tensor import DeviceMesh, DTensor + +from sglang.srt.entrypoints.http_server import launch_server +from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + popen_launch_server, +) + + +def serialize_for_http(data): + # First pickle the data, then convert to base64 for safe HTTP transmission + pickled = pickle.dumps(data) + return base64.b64encode(pickled).decode("utf-8") + + +import dataclasses + + +def server_args_to_launch_params(args: ServerArgs, timeout: float = 60.0): + # 1. model path + model = args.model_path + + # 2. base url + base_url = args.url() + + # 3. timeout + timeout = timeout + + # 4. api key + api_key = args.api_key + + # 5. 
other args: convert to CLI style, excluding handled keys + exclude_keys = {"model_path", "host", "port", "api_key"} + other_args = [] + + for field in dataclasses.fields(ServerArgs): + key = field.name + if key in exclude_keys: + continue + val = getattr(args, key) + if isinstance(val, bool): + if val: + other_args.append(f"--{key.replace('_', '-')}") + elif val is not None: + if isinstance(val, list): + for v in val: + other_args.extend([f"--{key.replace('_', '-')}", str(v)]) + else: + other_args.extend([f"--{key.replace('_', '-')}", str(val)]) + + return model, base_url, timeout, api_key, other_args + + +class HttpServerEngineAdapter: + def __init__(self, server_args: ServerArgs): + self.server_args = copy.deepcopy(server_args) + self.server_args.port = 2157 + print(f"launch_server_from_verl_engine {self.server_args.port}") + + # server_thread = threading.Thread( + # target=launch_server, + # args=(self.server_args,), + # daemon=True, + # ) + model, base_url, timeout, api_key, other_args = server_args_to_launch_params( + self.server_args + ) + self.process = popen_launch_server( + model=model, + base_url=base_url, + timeout=timeout, + api_key=api_key, + other_args=other_args, + ) + + def update_weights_from_tensor( + self, + named_tensors: List[Tuple[str, torch.Tensor]], + load_format: Optional[str] = None, + flush_cache: bool = False, + ): + + # obj = UpdateWeightsFromTensorReqInput( + # serialized_named_tensors=[ + # MultiprocessingSerializer.serialize(named_tensors) + # for _ in range(self.server_args.tp_size) + # ], + # load_format=load_format, + # flush_cache=flush_cache, + # ) + + print(f"update_weights_from_tensor of HttpServerEngineAdapter") + return requests.post( + f"http://localhost:{self.server_args.port}/update_weights_from_tensor", + json={ + "serialized_named_tensors": [ + serialize_for_http( + MultiprocessingSerializer.serialize(named_tensors) + ) + for _ in range(self.server_args.tp_size) + ], + "load_format": load_format, + "flush_cache": flush_cache, + }, + ) + + def shutdown(self): + kill_process_tree(self.process.pid) diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index 13f60451e0b..ad7c82b3404 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -18,9 +18,11 @@ import torch.distributed as dist from torch.distributed.tensor import DeviceMesh, DTensor +from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server import Engine +from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj @@ -39,11 +41,26 @@ def __init__( node_rank = self._tp_rank // tp_size_per_node first_rank_in_node = self._tp_rank % tp_size_per_node == 0 - if first_rank_in_node: + if first_rank_in_node and "launch_server" not in kwargs: os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" self._engine = Engine( **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes ) + elif "launch_server" in kwargs and kwargs["launch_server"]: + del kwargs["launch_server"] + if "server_args" in kwargs: + # Directly load server_args + server_args = kwargs["server_args"] + else: + # Construct server_args from kwargs + if "log_level" not in kwargs: + # Do not print logs by default + kwargs["log_level"] = "error" + server_args = ServerArgs(**kwargs) 
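+            # Only the first TP rank launches and talks to the HTTP server; the
+            # remaining ranks keep no engine handle and just join the barrier below.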
+ if self._tp_rank == 0: + self._engine = HttpServerEngineAdapter(server_args) + else: + self._engine = None else: self._engine = None diff --git a/python/test_verl_engine_server.py b/python/test_verl_engine_server.py new file mode 100644 index 00000000000..90f5dd46e25 --- /dev/null +++ b/python/test_verl_engine_server.py @@ -0,0 +1,315 @@ +import multiprocessing +import multiprocessing as mp +import os +import random +import traceback +import unittest +from multiprocessing import Process + +import torch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import CPUOffload +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.api import ( + ShardedStateDictConfig, + ShardingStrategy, + StateDictType, +) +from transformers import AutoModelForCausalLM + +from sglang.srt.entrypoints.verl_engine import VerlEngine +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import is_port_available +from sglang.test.runners import ( + HFRunner, + SRTRunner, + check_close_model_outputs, + get_dtype_str, +) +from sglang.test.test_utils import CustomTestCase, is_in_ci + +_MAX_NEW_TOKENS = 8 +_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="] +_TORCH_DTYPE = torch.float16 + +# Set to false to temporarily debug issues unrelated to weight update +_ENABLE_UPDATE_WEIGHTS = True +# _ENABLE_UPDATE_WEIGHTS = False + +# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py? +# CI_MODELS = [ +# dict(model_path="meta-llama/Llama-3.1-8B-Instruct"), +# # Fail to run gemma-2-2b after transformers==4.48.3 -> 4.50.0 +# # dict(model_path="google/gemma-2-2b"), +# ] +# ALL_OTHER_MODELS = [ +# dict(model_path="meta-llama/Llama-3.2-1B-Instruct"), +# dict(model_path="Qwen/Qwen2-1.5B"), +# dict( +# model_path="Qwen/Qwen2.5-14B-Instruct", +# mem_fraction_static=0.4, +# tp_size=8, +# tight_memory=True, +# decode_tolerance=1.3, +# ), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error +# dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3), +# dict(model_path="allenai/OLMo-1B-0724-hf"), +# dict( +# model_path="THUDM/glm-4-9b-chat", +# mem_fraction_static=0.1, +# tp_size=8, +# tight_memory=True, +# ), +# dict(model_path="allenai/OLMo-2-1124-7B-Instruct"), +# dict( +# model_path="ibm-granite/granite-3.0-2b-instruct", +# prefill_tolerance=0.22, +# decode_tolerance=0.22, +# ), +# # Fail to run these models in test_generation_models.py, need to fix that first +# # dict(model_path="openai-community/gpt2"), +# # dict(model_path="microsoft/Phi-3-small-8k-instruct"), +# ] +CI_MODELS = ALL_OTHER_MODELS = [ + dict( + model_path="Qwen/Qwen2.5-1.5B", + tp_size=2, + ) +] + + +class TestVerlEngine(CustomTestCase): + @classmethod + def setUpClass(cls): + multiprocessing.set_start_method("spawn") + + def assert_fragment_e2e_execution( + self, + index: int, + model_path: str, + mem_fraction_static: float = 0.4, + tp_size: int = 2, + tight_memory: bool = False, + prefill_tolerance: float = 0.1, + decode_tolerance: float = 0.1, + ): + master_port = find_available_port(23456) + + print(f"assert_fragment_e2e_execution START {index=} {model_path=}") + + processes = [] + output_reader, output_writer = mp.Pipe(duplex=False) + for tp_rank in range(tp_size): + p = Process( + target=_run_subprocess, + kwargs=dict( + tp_rank=tp_rank, + tp_size=tp_size, + master_port=master_port, + 
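+                    # writer end of the mp.Pipe created above; each subprocess reports
+                    # success/failure back to the main test process through it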
output_writer=output_writer, + model_path=model_path, + mem_fraction_static=mem_fraction_static, + tight_memory=tight_memory, + prefill_tolerance=prefill_tolerance, + decode_tolerance=decode_tolerance, + ), + ) + p.start() + processes.append(p) + + for _ in range(tp_size): + self.assertTrue( + output_reader.recv(), + f"Subprocess has error, please see logs above. ({index=} {model_path=})", + ) + + for p in processes: + p.join() + + def test_ci_models(self): + for index, model_info in enumerate(CI_MODELS): + self.assert_fragment_e2e_execution(index=index, **model_info) + + def test_others(self): + if is_in_ci(): + return + + for index, model_info in enumerate(ALL_OTHER_MODELS): + self.assert_fragment_e2e_execution(index=index, **model_info) + + # def test_adhoc(self): + # self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct") + + +def _run_subprocess( + tp_rank: int, + tp_size: int, + master_port: int, + output_writer, + model_path: str, + mem_fraction_static: float, + tight_memory: bool, + prefill_tolerance: float, + decode_tolerance: float, +): + try: + print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}") + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size) + torch.cuda.set_device(tp_rank) + + mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"]) + inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs) + inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs) + print( + f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}" + ) + + # hf model is used for comparison + hf_model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True + ).cuda() + hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True) + + hf_outputs = HFRunner.forward_generation_raw( + base_model=hf_model, + prompts=_PROMPTS, + max_new_tokens=_MAX_NEW_TOKENS, + tokenizer=hf_tokenizer, + lora_paths=None, + torch_dtype=_TORCH_DTYPE, + output_str_only=False, + ) + print( + f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}", + flush=True, + ) + + if _ENABLE_UPDATE_WEIGHTS: + if tight_memory: + hf_model.cpu() + torch.cuda.empty_cache() + + # test update weights + print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True) + fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size) + + engine = VerlEngine( + model_path=model_path, + load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto", + mem_fraction_static=mem_fraction_static, + random_seed=42, + trust_remote_code=True, + dtype=get_dtype_str(_TORCH_DTYPE), + device_mesh_cpu=inference_device_mesh_cpu["tp"], + launch_server=True, + ) + print(f"subprocess[{tp_rank=}] {engine=}", flush=True) + + if _ENABLE_UPDATE_WEIGHTS: + print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True) + engine.update_weights_from_tensor( + [(k, v) for k, v in fsdp_state_dict.items()] + ) + + # for enable_batch in [False, True]: + # if enable_batch: + # fn = SRTRunner.batch_forward_generation_raw + # else: + # fn = SRTRunner.forward_generation_raw + + # srt_outputs = fn( + # prompts=_PROMPTS, + # max_new_tokens=_MAX_NEW_TOKENS, + # lora_paths=None, + # engine=engine, + # ) + # print( + # f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}", + # flush=True, + # ) + + # check_close_model_outputs( + # hf_outputs=hf_outputs, + # 
srt_outputs=srt_outputs, + # prefill_tolerance=prefill_tolerance, + # decode_tolerance=decode_tolerance, + # rouge_l_tolerance=1, + # check_logprobs=not enable_batch, + # debug_text=f"{enable_batch=} {tp_rank=}", + # ) + from openai import OpenAI + + client = OpenAI(api_key="None", base_url="http://localhost:2157/v1") + print(client.models.list().data[0].id) + import requests + + url = f"http://localhost:{2157}/generate" + data = {"text": "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="} + response = requests.post(url, json=data) + print(response.json()) + + execution_ok = True + + except Exception as e: + print(f"subprocess[{tp_rank=}] has error: {e}", flush=True) + traceback.print_exc() + execution_ok = False + + output_writer.send(execution_ok) + output_writer.close() + + engine.shutdown() + print(f"subprocess[{tp_rank=}] end", flush=True) + + +# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py +def _get_fsdp_state_dict(hf_model, tp_size: int): + device_mesh = init_device_mesh( + "cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"] + ) + + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) + fsdp_model = FSDP( + hf_model, + use_orig_params=True, + auto_wrap_policy=None, + device_id=torch.cuda.current_device(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + cpu_offload=CPUOffload(offload_params=False), + sync_module_states=False, + device_mesh=device_mesh, + ) + print(f"{fsdp_model=}") + + FSDP.set_state_dict_type( + fsdp_model, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) + + return fsdp_model.state_dict() + + +# TODO Ask: this is extracted from PortArgs.init_new, is it allowed to extract it, i.e. 
touch that old code +def find_available_port(base_port: int): + port = base_port + random.randint(100, 1000) + while True: + if is_port_available(port): + return port + if port < 60000: + port += 42 + else: + port -= 43 + + +if __name__ == "__main__": + unittest.main() From e79d178b382efa215353e4512681097a95b731ea Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Fri, 28 Mar 2025 01:46:15 +0000 Subject: [PATCH 02/25] Add other 3 APIs --- .../srt/entrypoints/http_server_engine.py | 77 +++++++++++++++---- python/test_verl_engine_server.py | 22 +++++- 2 files changed, 82 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 8b7743b8ad4..01663f49f02 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -70,11 +70,6 @@ def __init__(self, server_args: ServerArgs): self.server_args.port = 2157 print(f"launch_server_from_verl_engine {self.server_args.port}") - # server_thread = threading.Thread( - # target=launch_server, - # args=(self.server_args,), - # daemon=True, - # ) model, base_url, timeout, api_key, other_args = server_args_to_launch_params( self.server_args ) @@ -93,15 +88,6 @@ def update_weights_from_tensor( flush_cache: bool = False, ): - # obj = UpdateWeightsFromTensorReqInput( - # serialized_named_tensors=[ - # MultiprocessingSerializer.serialize(named_tensors) - # for _ in range(self.server_args.tp_size) - # ], - # load_format=load_format, - # flush_cache=flush_cache, - # ) - print(f"update_weights_from_tensor of HttpServerEngineAdapter") return requests.post( f"http://localhost:{self.server_args.port}/update_weights_from_tensor", @@ -119,3 +105,66 @@ def update_weights_from_tensor( def shutdown(self): kill_process_tree(self.process.pid) + + def generate( + self, + prompt=None, + sampling_params=None, + input_ids=None, + image_data=None, + return_logprob=False, + logprob_start_len=None, + top_logprobs_num=None, + token_ids_logprob=None, + lora_path=None, + custom_logit_processor=None, + ): + """Implements text generation functionality by forwarding the request to the veRL Engine via HTTP. + + This method packages all generation parameters into a JSON payload, filters out any None values, + and sends the request to the locally running server. It then returns the parsed response or + raises an exception if the generation fails. 
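+
+        Example (illustrative only; `adapter` stands for a constructed
+        HttpServerEngineAdapter whose server is already reachable on the configured port):
+            adapter.generate(prompt="1+1=", sampling_params={"temperature": 0, "max_new_tokens": 8})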
+ """ + payload = { + "text": prompt, + "sampling_params": sampling_params, + "input_ids": input_ids, + "image_data": image_data, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "token_ids_logprob": token_ids_logprob, + "lora_path": lora_path, + "custom_logit_processor": custom_logit_processor, + } + # Filter out None values + payload = {k: v for k, v in payload.items() if v is not None} + + response = requests.post( + f"http://localhost:{self.server_args.port}/generate", json=payload + ) + + if response.status_code == 200: + return response.json() + else: + raise Exception(f"Generate request failed: {response.text}") + + def release_memory_occupation(self): + """release memory occupation by HTTP""" + response = requests.post( + f"http://localhost:{self.server_args.port}/release_memory_occupation", + json={}, + ) + if response.status_code != 200: + raise Exception(f"Failed to release memory: {response.text}") + return response + + def resume_memory_occupation(self): + """resume memory occupation by HTTP""" + response = requests.post( + f"http://localhost:{self.server_args.port}/resume_memory_occupation", + json={}, + ) + if response.status_code != 200: + raise Exception(f"Failed to resume memory: {response.text}") + return response diff --git a/python/test_verl_engine_server.py b/python/test_verl_engine_server.py index 90f5dd46e25..b294605ae91 100644 --- a/python/test_verl_engine_server.py +++ b/python/test_verl_engine_server.py @@ -2,6 +2,7 @@ import multiprocessing as mp import os import random +import time import traceback import unittest from multiprocessing import Process @@ -138,9 +139,6 @@ def test_others(self): for index, model_info in enumerate(ALL_OTHER_MODELS): self.assert_fragment_e2e_execution(index=index, **model_info) - # def test_adhoc(self): - # self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct") - def _run_subprocess( tp_rank: int, @@ -241,6 +239,24 @@ def _run_subprocess( # check_logprobs=not enable_batch, # debug_text=f"{enable_batch=} {tp_rank=}", # ) + + # test direct generate API + print(f"subprocess[{tp_rank=}] testing direct generate API") + direct_response = engine.generate( + prompt="Hello, world!", + sampling_params={"temperature": 0.7, "max_new_tokens": 20}, + ) + print(f"Direct generate response: {direct_response}") + + # test memory occupation APIs + print(f"subprocess[{tp_rank=}] testing memory occupation APIs") + engine.release_memory_occupation() + print("Memory released") + time.sleep(1) + engine.resume_memory_occupation() + print("Memory resumed") + + # openai API test for reference from openai import OpenAI client = OpenAI(api_key="None", base_url="http://localhost:2157/v1") From cd89cc3b8d6a9fe5a2b142c246ccaa532e09eea9 Mon Sep 17 00:00:00 2001 From: yitianlian Date: Fri, 28 Mar 2025 04:59:03 +0000 Subject: [PATCH 03/25] update http_server_engine and test --- .../srt/entrypoints/http_server_engine.py | 5 +- .../srt}/test_verl_engine_server.py | 140 ++++++------------ 2 files changed, 50 insertions(+), 95 deletions(-) rename {python => test/srt}/test_verl_engine_server.py (74%) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 01663f49f02..35be39131d5 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -29,7 +29,7 @@ def serialize_for_http(data): import dataclasses -def 
server_args_to_launch_params(args: ServerArgs, timeout: float = 60.0): +def server_args_to_launch_params(args: ServerArgs, timeout: float = 120.0): # 1. model path model = args.model_path @@ -67,7 +67,6 @@ def server_args_to_launch_params(args: ServerArgs, timeout: float = 60.0): class HttpServerEngineAdapter: def __init__(self, server_args: ServerArgs): self.server_args = copy.deepcopy(server_args) - self.server_args.port = 2157 print(f"launch_server_from_verl_engine {self.server_args.port}") model, base_url, timeout, api_key, other_args = server_args_to_launch_params( @@ -78,7 +77,7 @@ def __init__(self, server_args: ServerArgs): base_url=base_url, timeout=timeout, api_key=api_key, - other_args=other_args, + other_args=other_args + ["--enable-memory-saver"], ) def update_weights_from_tensor( diff --git a/python/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py similarity index 74% rename from python/test_verl_engine_server.py rename to test/srt/test_verl_engine_server.py index b294605ae91..f6879a297e7 100644 --- a/python/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -7,7 +7,9 @@ import unittest from multiprocessing import Process +import requests import torch +from openai import OpenAI from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import CPUOffload from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -21,6 +23,7 @@ from sglang.srt.entrypoints.verl_engine import VerlEngine from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available from sglang.test.runners import ( HFRunner, @@ -36,42 +39,7 @@ # Set to false to temporarily debug issues unrelated to weight update _ENABLE_UPDATE_WEIGHTS = True -# _ENABLE_UPDATE_WEIGHTS = False - -# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py? -# CI_MODELS = [ -# dict(model_path="meta-llama/Llama-3.1-8B-Instruct"), -# # Fail to run gemma-2-2b after transformers==4.48.3 -> 4.50.0 -# # dict(model_path="google/gemma-2-2b"), -# ] -# ALL_OTHER_MODELS = [ -# dict(model_path="meta-llama/Llama-3.2-1B-Instruct"), -# dict(model_path="Qwen/Qwen2-1.5B"), -# dict( -# model_path="Qwen/Qwen2.5-14B-Instruct", -# mem_fraction_static=0.4, -# tp_size=8, -# tight_memory=True, -# decode_tolerance=1.3, -# ), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error -# dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3), -# dict(model_path="allenai/OLMo-1B-0724-hf"), -# dict( -# model_path="THUDM/glm-4-9b-chat", -# mem_fraction_static=0.1, -# tp_size=8, -# tight_memory=True, -# ), -# dict(model_path="allenai/OLMo-2-1124-7B-Instruct"), -# dict( -# model_path="ibm-granite/granite-3.0-2b-instruct", -# prefill_tolerance=0.22, -# decode_tolerance=0.22, -# ), -# # Fail to run these models in test_generation_models.py, need to fix that first -# # dict(model_path="openai-community/gpt2"), -# # dict(model_path="microsoft/Phi-3-small-8k-instruct"), -# ] + CI_MODELS = ALL_OTHER_MODELS = [ dict( model_path="Qwen/Qwen2.5-1.5B", @@ -80,6 +48,21 @@ ] +# TODO Ask: this is extracted from PortArgs.init_new, is it allowed to extract it, i.e. 
touch that old code +def find_available_port(base_port: int): + port = base_port + random.randint(100, 1000) + while True: + if is_port_available(port): + return port + if port < 60000: + port += 42 + else: + port -= 43 + + +PORT = find_available_port(2345) + + class TestVerlEngine(CustomTestCase): @classmethod def setUpClass(cls): @@ -204,42 +187,15 @@ def _run_subprocess( dtype=get_dtype_str(_TORCH_DTYPE), device_mesh_cpu=inference_device_mesh_cpu["tp"], launch_server=True, + server_args=ServerArgs( + model_path=model_path, + tp_size=tp_size, + port=PORT, + mem_fraction_static=0.5, + ), ) print(f"subprocess[{tp_rank=}] {engine=}", flush=True) - if _ENABLE_UPDATE_WEIGHTS: - print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True) - engine.update_weights_from_tensor( - [(k, v) for k, v in fsdp_state_dict.items()] - ) - - # for enable_batch in [False, True]: - # if enable_batch: - # fn = SRTRunner.batch_forward_generation_raw - # else: - # fn = SRTRunner.forward_generation_raw - - # srt_outputs = fn( - # prompts=_PROMPTS, - # max_new_tokens=_MAX_NEW_TOKENS, - # lora_paths=None, - # engine=engine, - # ) - # print( - # f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}", - # flush=True, - # ) - - # check_close_model_outputs( - # hf_outputs=hf_outputs, - # srt_outputs=srt_outputs, - # prefill_tolerance=prefill_tolerance, - # decode_tolerance=decode_tolerance, - # rouge_l_tolerance=1, - # check_logprobs=not enable_batch, - # debug_text=f"{enable_batch=} {tp_rank=}", - # ) - # test direct generate API print(f"subprocess[{tp_rank=}] testing direct generate API") direct_response = engine.generate( @@ -255,18 +211,30 @@ def _run_subprocess( time.sleep(1) engine.resume_memory_occupation() print("Memory resumed") - + time.sleep(1) # openai API test for reference - from openai import OpenAI - - client = OpenAI(api_key="None", base_url="http://localhost:2157/v1") - print(client.models.list().data[0].id) - import requests - - url = f"http://localhost:{2157}/generate" - data = {"text": "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="} - response = requests.post(url, json=data) - print(response.json()) + torch.distributed.barrier() + if tp_rank == 0: + client = OpenAI(api_key="None", base_url=f"http://localhost:{PORT}/v1") + print(client.models.list().data[0].id) + url = f"http://localhost:{PORT}/generate" + data = {"text": "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="} + response = requests.post(url, json=data) + print(response.json()) + if _ENABLE_UPDATE_WEIGHTS: + print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True) + # check_tensor = [(k, v) for k, v in fsdp_state_dict.items()][0] + # update_tensor = [check_tensor[0], torch.zeros_like(check_tensor[1])] + engine.update_weights_from_tensor( + [(k, v) for k, v in fsdp_state_dict.items()] + ) + # if tp_rank == 0: + # response = requests.get( + # f"http://localhost:{PORT}/get_weights_by_name", + # json={"name": list(fsdp_state_dict.keys())[0], "truncate_size": 5}, + # timeout=20, + # ) + # print(response.json()) execution_ok = True @@ -315,17 +283,5 @@ def _get_fsdp_state_dict(hf_model, tp_size: int): return fsdp_model.state_dict() -# TODO Ask: this is extracted from PortArgs.init_new, is it allowed to extract it, i.e. 
touch that old code -def find_available_port(base_port: int): - port = base_port + random.randint(100, 1000) - while True: - if is_port_available(port): - return port - if port < 60000: - port += 42 - else: - port -= 43 - - if __name__ == "__main__": unittest.main() From 20be9288405ff43e1b1a8b42696fbef460b03723 Mon Sep 17 00:00:00 2001 From: yitianlian Date: Fri, 4 Apr 2025 14:06:40 +0000 Subject: [PATCH 04/25] revise most of problems in comments --- python/sglang/srt/entrypoints/http_server.py | 39 ++-- .../srt/entrypoints/http_server_engine.py | 170 +++++++++--------- python/sglang/srt/entrypoints/verl_engine.py | 33 ++-- python/sglang/srt/utils.py | 12 ++ test/srt/test_verl_engine_server.py | 18 +- 5 files changed, 145 insertions(+), 127 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index eb19c0c77e3..a7e88d4416a 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -26,11 +26,13 @@ import threading import time from http import HTTPStatus -from typing import AsyncIterator, Callable, Dict, Optional +from typing import AsyncIterator, Callable, Dict, Optional, Union # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) +import base64 +import pickle from contextlib import asynccontextmanager import numpy as np @@ -81,6 +83,7 @@ from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( + HttpSerializer, add_api_key_middleware, add_prometheus_middleware, delete_directory, @@ -412,30 +415,30 @@ async def init_weights_update_group( return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) -def deserialize_from_http(encoded_data): - import base64 - import pickle - - # Convert from base64 string back to original data - pickled = base64.b64decode(encoded_data) - return pickle.loads(pickled) - - @app.post("/update_weights_from_tensor") async def update_weights_from_tensor( - obj: UpdateWeightsFromTensorReqInput, request: Request + obj: Union[UpdateWeightsFromTensorReqInput, str], request: Request ): - obj.serialized_named_tensors = [ - deserialize_from_http(item) for item in obj.serialized_named_tensors - ] + if isinstance(obj, str): + try: + obj = HttpSerializer.deserialize(obj) + except Exception as e: + return ORJSONResponse( + {"success": False, "message": f"Failed to decode input: {str(e)}"}, + status_code=HTTPStatus.BAD_REQUEST, + ) + else: + obj.serialized_named_tensors = [ + HttpSerializer.deserialize(item) for item in obj.serialized_named_tensors + ] + success, message = await _global_state.tokenizer_manager.update_weights_from_tensor( obj, request ) content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + return ORJSONResponse( + content, status_code=200 if success else HTTPStatus.BAD_REQUEST + ) @app.post("/update_weights_from_distributed") diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 35be39131d5..ed3833ed035 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -1,84 +1,107 @@ import base64 import copy +import dataclasses +import multiprocessing import pickle import threading import time -from typing import Dict, List, Optional, Tuple, Union +from 
typing import Any, Dict, List, Optional, Tuple, Union import requests import torch import torch.distributed as dist -from torch.distributed.tensor import DeviceMesh, DTensor from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - popen_launch_server, +from sglang.srt.utils import ( + HttpSerializer, + MultiprocessingSerializer, + kill_process_tree, ) +from sglang.test.test_utils import popen_launch_server -def serialize_for_http(data): - # First pickle the data, then convert to base64 for safe HTTP transmission - pickled = pickle.dumps(data) - return base64.b64encode(pickled).decode("utf-8") +def launch_server_worker(server_args: ServerArgs): + launch_server(server_args) -import dataclasses +def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: + + p = multiprocessing.Process(target=launch_server_worker, args=(server_args,)) + p.start() + + base_url = server_args.url() + timeout = 120.0 + start_time = time.time() + + with requests.Session() as session: + while time.time() - start_time < timeout: + try: + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {server_args.api_key}", + } + response = session.get(f"{base_url}/health_generate", headers=headers) + if response.status_code == 200: + return p + except requests.RequestException: + pass + if not p.is_alive(): + raise Exception("Server process terminated unexpectedly.") -def server_args_to_launch_params(args: ServerArgs, timeout: float = 120.0): - # 1. model path - model = args.model_path + time.sleep(2) - # 2. base url - base_url = args.url() + p.terminate() + raise TimeoutError("Server failed to start within the timeout period.") - # 3. timeout - timeout = timeout - # 4. api key - api_key = args.api_key +# def convert_dataclass_to_argparse( +# args: Any, exclude_keys: set[str] = set() +# ) -> List[str]: +# """Convert a dataclass instance into a list of CLI-style arguments.""" +# cli_args = [] +# for field in dataclasses.fields(args): +# key = field.name +# if key in exclude_keys: +# continue +# val = getattr(args, key) +# cli_key = f"--{key.replace('_', '-')}" +# if isinstance(val, bool): +# if val: +# cli_args.append(cli_key) +# elif val is not None: +# if isinstance(val, list): +# for v in val: +# cli_args.extend([cli_key, str(v)]) +# else: +# cli_args.extend([cli_key, str(val)]) +# return cli_args - # 5. 
other args: convert to CLI style, excluding handled keys - exclude_keys = {"model_path", "host", "port", "api_key"} - other_args = [] - for field in dataclasses.fields(ServerArgs): - key = field.name - if key in exclude_keys: - continue - val = getattr(args, key) - if isinstance(val, bool): - if val: - other_args.append(f"--{key.replace('_', '-')}") - elif val is not None: - if isinstance(val, list): - for v in val: - other_args.extend([f"--{key.replace('_', '-')}", str(v)]) - else: - other_args.extend([f"--{key.replace('_', '-')}", str(val)]) +# def server_args_to_launch_params(args: Any, timeout: float = 120.0): +# model = args.model_path +# base_url = args.url() +# timeout = timeout +# api_key = args.api_key - return model, base_url, timeout, api_key, other_args +# exclude_keys = {"model_path", "host", "port", "api_key"} +# other_args = convert_dataclass_to_argparse(args, exclude_keys=exclude_keys) + +# return model, base_url, timeout, api_key, other_args class HttpServerEngineAdapter: - def __init__(self, server_args: ServerArgs): - self.server_args = copy.deepcopy(server_args) + def __init__(self, **kwargs): + self.server_args = ServerArgs(**kwargs) + # self.server_args = copy.deepcopy(server_args) print(f"launch_server_from_verl_engine {self.server_args.port}") + self.process = launch_server_process(self.server_args) - model, base_url, timeout, api_key, other_args = server_args_to_launch_params( - self.server_args - ) - self.process = popen_launch_server( - model=model, - base_url=base_url, - timeout=timeout, - api_key=api_key, - other_args=other_args + ["--enable-memory-saver"], - ) + def _url(self, path: str) -> str: + """Construct full URL for server endpoint.""" + return f"http://{self.server_args.host}:{self.server_args.port}/{path}" def update_weights_from_tensor( self, @@ -88,11 +111,11 @@ def update_weights_from_tensor( ): print(f"update_weights_from_tensor of HttpServerEngineAdapter") - return requests.post( - f"http://localhost:{self.server_args.port}/update_weights_from_tensor", + response = requests.post( + self._url("update_weights_from_tensor"), json={ "serialized_named_tensors": [ - serialize_for_http( + HttpSerializer.serialize( MultiprocessingSerializer.serialize(named_tensors) ) for _ in range(self.server_args.tp_size) @@ -101,6 +124,8 @@ def update_weights_from_tensor( "flush_cache": flush_cache, }, ) + response.raise_for_status() + return response.json() def shutdown(self): kill_process_tree(self.process.pid) @@ -118,12 +143,6 @@ def generate( lora_path=None, custom_logit_processor=None, ): - """Implements text generation functionality by forwarding the request to the veRL Engine via HTTP. - - This method packages all generation parameters into a JSON payload, filters out any None values, - and sends the request to the locally running server. It then returns the parsed response or - raises an exception if the generation fails. 
- """ payload = { "text": prompt, "sampling_params": sampling_params, @@ -139,31 +158,16 @@ def generate( # Filter out None values payload = {k: v for k, v in payload.items() if v is not None} - response = requests.post( - f"http://localhost:{self.server_args.port}/generate", json=payload - ) - - if response.status_code == 200: - return response.json() - else: - raise Exception(f"Generate request failed: {response.text}") + response = requests.post(self._url("generate"), json=payload) + response.raise_for_status() + return response.json() def release_memory_occupation(self): - """release memory occupation by HTTP""" - response = requests.post( - f"http://localhost:{self.server_args.port}/release_memory_occupation", - json={}, - ) - if response.status_code != 200: - raise Exception(f"Failed to release memory: {response.text}") - return response + response = requests.post(self._url("release_memory_occupation"), json={}) + response.raise_for_status() + return response.json() def resume_memory_occupation(self): - """resume memory occupation by HTTP""" - response = requests.post( - f"http://localhost:{self.server_args.port}/resume_memory_occupation", - json={}, - ) - if response.status_code != 200: - raise Exception(f"Failed to resume memory: {response.text}") - return response + response = requests.post(self._url("resume_memory_occupation"), json={}) + response.raise_for_status() + return response.json() diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index ad7c82b3404..a69c135a084 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -12,7 +12,7 @@ # limitations under the License. # ============================================================================== import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union import torch import torch.distributed as dist @@ -31,6 +31,7 @@ def __init__( self, device_mesh_cpu: DeviceMesh, nnodes: int = 1, + backend: Literal["engine", "server"] = "engine", **kwargs, ): monkey_patch_torch_reductions() @@ -41,28 +42,24 @@ def __init__( node_rank = self._tp_rank // tp_size_per_node first_rank_in_node = self._tp_rank % tp_size_per_node == 0 - if first_rank_in_node and "launch_server" not in kwargs: - os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = Engine( - **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes - ) - elif "launch_server" in kwargs and kwargs["launch_server"]: - del kwargs["launch_server"] - if "server_args" in kwargs: - # Directly load server_args - server_args = kwargs["server_args"] + if backend == "engine": + if first_rank_in_node: + os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" + self._engine = Engine( + **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes + ) else: - # Construct server_args from kwargs - if "log_level" not in kwargs: - # Do not print logs by default - kwargs["log_level"] = "error" - server_args = ServerArgs(**kwargs) + self._engine = None + + elif backend == "server": if self._tp_rank == 0: - self._engine = HttpServerEngineAdapter(server_args) + self._engine = HttpServerEngineAdapter( + **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes + ) else: self._engine = None else: - self._engine = None + raise ValueError(f"Unsupported backend: {backend}") dist.barrier(group=self._device_mesh_cpu.get_group()) diff --git a/python/sglang/srt/utils.py 
b/python/sglang/srt/utils.py index 498bc58ccd0..cce98a571b2 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1531,6 +1531,18 @@ def deserialize(data): return ForkingPickler.loads(data) +class HttpSerializer: + @staticmethod + def serialize(data): + pickled = pickle.dumps(data) + return base64.b64encode(pickled).decode("utf-8") + + @staticmethod + def deserialize(data): + pickled = base64.b64decode(data) + return pickle.loads(pickled) + + def debug_timing(func): # todo: replace with a more organized instrumentation def wrapper(*args, **kwargs): diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py index f6879a297e7..564b7e8ea68 100644 --- a/test/srt/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -186,13 +186,9 @@ def _run_subprocess( trust_remote_code=True, dtype=get_dtype_str(_TORCH_DTYPE), device_mesh_cpu=inference_device_mesh_cpu["tp"], - launch_server=True, - server_args=ServerArgs( - model_path=model_path, - tp_size=tp_size, - port=PORT, - mem_fraction_static=0.5, - ), + backend="server", + enable_memory_saver=True, + port=PORT, ) print(f"subprocess[{tp_rank=}] {engine=}", flush=True) @@ -211,7 +207,7 @@ def _run_subprocess( time.sleep(1) engine.resume_memory_occupation() print("Memory resumed") - time.sleep(1) + time.sleep(2) # openai API test for reference torch.distributed.barrier() if tp_rank == 0: @@ -235,6 +231,12 @@ def _run_subprocess( # timeout=20, # ) # print(response.json()) + print(f"subprocess[{tp_rank=}] testing direct generate API") + direct_response = engine.generate( + prompt="Hello, world!", + sampling_params={"temperature": 0.7, "max_new_tokens": 20}, + ) + print(f"Direct generate response: {direct_response}") execution_ok = True From 7b5dfae93a81cdc4f1bacbb678d780103a15c258 Mon Sep 17 00:00:00 2001 From: yitianlian Date: Fri, 4 Apr 2025 14:12:18 +0000 Subject: [PATCH 05/25] revise most of problems in comments --- .../srt/entrypoints/http_server_engine.py | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index ed3833ed035..8c7ee457825 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -19,7 +19,6 @@ MultiprocessingSerializer, kill_process_tree, ) -from sglang.test.test_utils import popen_launch_server def launch_server_worker(server_args: ServerArgs): @@ -57,41 +56,6 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: raise TimeoutError("Server failed to start within the timeout period.") -# def convert_dataclass_to_argparse( -# args: Any, exclude_keys: set[str] = set() -# ) -> List[str]: -# """Convert a dataclass instance into a list of CLI-style arguments.""" -# cli_args = [] -# for field in dataclasses.fields(args): -# key = field.name -# if key in exclude_keys: -# continue -# val = getattr(args, key) -# cli_key = f"--{key.replace('_', '-')}" -# if isinstance(val, bool): -# if val: -# cli_args.append(cli_key) -# elif val is not None: -# if isinstance(val, list): -# for v in val: -# cli_args.extend([cli_key, str(v)]) -# else: -# cli_args.extend([cli_key, str(val)]) -# return cli_args - - -# def server_args_to_launch_params(args: Any, timeout: float = 120.0): -# model = args.model_path -# base_url = args.url() -# timeout = timeout -# api_key = args.api_key - -# exclude_keys = {"model_path", "host", "port", "api_key"} -# other_args = 
convert_dataclass_to_argparse(args, exclude_keys=exclude_keys) - -# return model, base_url, timeout, api_key, other_args - - class HttpServerEngineAdapter: def __init__(self, **kwargs): self.server_args = ServerArgs(**kwargs) From 181030c5dc89a72918a415e6b91757022acc886d Mon Sep 17 00:00:00 2001 From: yitianlian Date: Fri, 4 Apr 2025 15:07:09 +0000 Subject: [PATCH 06/25] reduce the serialize number --- python/sglang/srt/entrypoints/http_server_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 8c7ee457825..5f6c17692e2 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -75,14 +75,15 @@ def update_weights_from_tensor( ): print(f"update_weights_from_tensor of HttpServerEngineAdapter") + serialized_named_tensors = HttpSerializer.serialize( + MultiprocessingSerializer.serialize(named_tensors) + ) + response = requests.post( self._url("update_weights_from_tensor"), json={ "serialized_named_tensors": [ - HttpSerializer.serialize( - MultiprocessingSerializer.serialize(named_tensors) - ) - for _ in range(self.server_args.tp_size) + serialized_named_tensors for _ in range(self.server_args.tp_size) ], "load_format": load_format, "flush_cache": flush_cache, From b0d9c5142f31c78bd01d43178012a713b0d3dfba Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Sat, 5 Apr 2025 01:02:43 +0000 Subject: [PATCH 07/25] Add Comment for update_weights_from_tensor --- python/sglang/srt/entrypoints/http_server_engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 5f6c17692e2..9d84e156d68 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -73,6 +73,12 @@ def update_weights_from_tensor( load_format: Optional[str] = None, flush_cache: bool = False, ): + """ + Update model weights from tensor data. + + Note: The model should be on GPUs rather than CPU for this functionality to work properly. + If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. + """ print(f"update_weights_from_tensor of HttpServerEngineAdapter") serialized_named_tensors = HttpSerializer.serialize( From c0b7b7d807271dd1238f96a7d14cc09367dfd0ff Mon Sep 17 00:00:00 2001 From: yitianlian Date: Sat, 5 Apr 2025 02:31:19 +0000 Subject: [PATCH 08/25] add base_engine(ABC) --- python/sglang/srt/entrypoints/base_engine.py | 53 +++++++++++++++++++ python/sglang/srt/entrypoints/engine.py | 3 +- .../srt/entrypoints/http_server_engine.py | 6 +-- 3 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 python/sglang/srt/entrypoints/base_engine.py diff --git a/python/sglang/srt/entrypoints/base_engine.py b/python/sglang/srt/entrypoints/base_engine.py new file mode 100644 index 00000000000..3bf3a86c633 --- /dev/null +++ b/python/sglang/srt/entrypoints/base_engine.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from typing import Dict, Iterator, List, Optional, Tuple, Union + +import torch + + +class EngineBase(ABC): + """ + Abstract base class for engine interfaces that support generation, weight updating, and memory control. + This base class provides a unified API for both HTTP-based engines and engines. 
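+    Known implementations in this change are the in-process `Engine` and the HTTP-based
+    `HttpServerEngineAdapter`; both expose generate, update_weights_from_tensor,
+    release_memory_occupation, resume_memory_occupation, and shutdown.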
+ """ + + @abstractmethod + def generate( + self, + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + image_data: Optional[Union[List[str], str]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + ) -> Union[Dict, Iterator[Dict]]: + """Generate text based on given inputs.""" + pass + + @abstractmethod + def update_weights_from_tensor( + self, + named_tensors: List[Tuple[str, torch.Tensor]], + load_format: Optional[str] = None, + flush_cache: bool = True, + ): + """Update model weights with in-memory tensor data.""" + pass + + @abstractmethod + def release_memory_occupation(self): + """Release GPU memory occupation temporarily.""" + pass + + @abstractmethod + def resume_memory_occupation(self): + """Resume GPU memory occupation previously released.""" + pass + + @abstractmethod + def shutdown(self): + """Shutdown the engine and clean up resources.""" + pass diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index b92c6ecdbb5..bf9ca49e5c4 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -37,6 +37,7 @@ import uvloop from sglang.srt.code_completion_parser import load_completion_template_for_openai_api +from sglang.srt.entrypoints.base_engine import EngineBase from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, ) @@ -77,7 +78,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -class Engine: +class Engine(EngineBase): """ The entry point to the inference engine. diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 9d84e156d68..443637f7de6 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -11,8 +11,8 @@ import torch import torch.distributed as dist +from sglang.srt.entrypoints.base_engine import EngineBase from sglang.srt.entrypoints.http_server import launch_server -from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( HttpSerializer, @@ -56,7 +56,7 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: raise TimeoutError("Server failed to start within the timeout period.") -class HttpServerEngineAdapter: +class HttpServerEngineAdapter(EngineBase): def __init__(self, **kwargs): self.server_args = ServerArgs(**kwargs) # self.server_args = copy.deepcopy(server_args) @@ -75,7 +75,7 @@ def update_weights_from_tensor( ): """ Update model weights from tensor data. - + Note: The model should be on GPUs rather than CPU for this functionality to work properly. If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. 
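
        Example (illustrative; mirrors how the test drives it, with `model` being any
        torch.nn.Module whose parameters should be pushed to the server):
            engine.update_weights_from_tensor(list(model.state_dict().items()))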
""" From 80df9c189b6a17b9389f544c09ed0f5931735f2b Mon Sep 17 00:00:00 2001 From: yitianlian Date: Sat, 5 Apr 2025 02:38:45 +0000 Subject: [PATCH 09/25] add docstring for update_weights_from_tensor --- python/sglang/srt/entrypoints/http_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index a7e88d4416a..ad2feb7759f 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -419,6 +419,10 @@ async def init_weights_update_group( async def update_weights_from_tensor( obj: Union[UpdateWeightsFromTensorReqInput, str], request: Request ): + """Update the weights from tensor inplace without re-launching the server. + Notes: + Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors. + """ if isinstance(obj, str): try: obj = HttpSerializer.deserialize(obj) From 10c1f431e154a8b14915935eec4afd64f0b1ae86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=93=E5=A4=A9=E8=8E=B2?= <91449279+yitianlian@users.noreply.github.com> Date: Mon, 7 Apr 2025 15:38:38 +0000 Subject: [PATCH 10/25] revise some code structure --- python/sglang/srt/entrypoints/http_server.py | 3 +- .../srt/entrypoints/http_server_engine.py | 6 +- python/sglang/srt/entrypoints/verl_engine.py | 4 +- python/sglang/test/test_utils.py | 12 ++- test/srt/test_verl_engine.py | 13 +-- test/srt/test_verl_engine_server.py | 96 +++++++++---------- 6 files changed, 62 insertions(+), 72 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index ad2feb7759f..2e60a80e3ae 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -421,7 +421,8 @@ async def update_weights_from_tensor( ): """Update the weights from tensor inplace without re-launching the server. Notes: - Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors. + 1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors. + 2. HTTPS will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model. 
""" if isinstance(obj, str): try: diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 443637f7de6..bb0118f01cd 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -31,7 +31,7 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: p.start() base_url = server_args.url() - timeout = 120.0 + timeout = 180.0 start_time = time.time() with requests.Session() as session: @@ -56,7 +56,7 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: raise TimeoutError("Server failed to start within the timeout period.") -class HttpServerEngineAdapter(EngineBase): +class HttpServerEngineForRL(EngineBase): def __init__(self, **kwargs): self.server_args = ServerArgs(**kwargs) # self.server_args = copy.deepcopy(server_args) @@ -80,7 +80,7 @@ def update_weights_from_tensor( If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. """ - print(f"update_weights_from_tensor of HttpServerEngineAdapter") + print(f"update_weights_from_tensor of HttpServerEngineForRL") serialized_named_tensors = HttpSerializer.serialize( MultiprocessingSerializer.serialize(named_tensors) ) diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index a69c135a084..966f843cc98 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -18,7 +18,7 @@ import torch.distributed as dist from torch.distributed.tensor import DeviceMesh, DTensor -from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter +from sglang.srt.entrypoints.http_server_engine import HttpServerEngineForRL from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server import Engine @@ -53,7 +53,7 @@ def __init__( elif backend == "server": if self._tp_rank == 0: - self._engine = HttpServerEngineAdapter( + self._engine = HttpServerEngineForRL( **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes ) else: diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index fe876c96096..e524f92e8d5 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -25,7 +25,7 @@ from sglang.global_config import global_config from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry +from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry, is_port_available from sglang.test.run_eval import run_eval from sglang.utils import get_exception_traceback @@ -102,6 +102,16 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None) pred = res.json()["generated_text"][0] return pred +def find_available_port(base_port: int): + port = base_port + random.randint(100, 1000) + while True: + if is_port_available(port): + return port + if port < 60000: + port += 42 + else: + port -= 43 + def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None): assert url is not None diff --git a/test/srt/test_verl_engine.py b/test/srt/test_verl_engine.py index 72c0d5225f7..f600aa3a6b3 100644 --- a/test/srt/test_verl_engine.py +++ b/test/srt/test_verl_engine.py @@ -27,7 +27,7 @@ check_close_model_outputs, get_dtype_str, ) -from 
sglang.test.test_utils import CustomTestCase, is_in_ci +from sglang.test.test_utils import CustomTestCase, is_in_ci, find_available_port _MAX_NEW_TOKENS = 8 _PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="] @@ -282,17 +282,6 @@ def _get_fsdp_state_dict(hf_model, tp_size: int): return fsdp_model.state_dict() -# TODO Ask: this is extracted from PortArgs.init_new, is it allowed to extract it, i.e. touch that old code -def find_available_port(base_port: int): - port = base_port + random.randint(100, 1000) - while True: - if is_port_available(port): - return port - if port < 60000: - port += 42 - else: - port -= 43 - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py index 564b7e8ea68..1baa475cc95 100644 --- a/test/srt/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -14,24 +14,15 @@ from torch.distributed.fsdp import CPUOffload from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision -from torch.distributed.fsdp.api import ( - ShardedStateDictConfig, - ShardingStrategy, - StateDictType, -) +from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType from transformers import AutoModelForCausalLM from sglang.srt.entrypoints.verl_engine import VerlEngine from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available -from sglang.test.runners import ( - HFRunner, - SRTRunner, - check_close_model_outputs, - get_dtype_str, -) -from sglang.test.test_utils import CustomTestCase, is_in_ci +from sglang.test.runners import HFRunner, SRTRunner, check_close_model_outputs, get_dtype_str +from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci _MAX_NEW_TOKENS = 8 _PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="] @@ -40,26 +31,39 @@ # Set to false to temporarily debug issues unrelated to weight update _ENABLE_UPDATE_WEIGHTS = True -CI_MODELS = ALL_OTHER_MODELS = [ +CI_MODELS = [ + dict(model_path="meta-llama/Llama-3.1-8B-Instruct"), + # Fail to run gemma-2-2b after transformers==4.48.3 -> 4.50.0 + # dict(model_path="google/gemma-2-2b"), +] +ALL_OTHER_MODELS = [ + # dict(model_path="meta-llama/Llama-3.2-1B-Instruct"), + dict(model_path="Qwen/Qwen2-1.5B"), dict( - model_path="Qwen/Qwen2.5-1.5B", - tp_size=2, - ) + model_path="Qwen/Qwen2.5-14B-Instruct", + mem_fraction_static=0.4, + tp_size=8, + tight_memory=True, + decode_tolerance=1.3, + ), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error + dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3), + # dict(model_path="allenai/OLMo-1B-0724-hf"), + # dict( + # model_path="THUDM/glm-4-9b-chat", + # mem_fraction_static=0.1, + # tp_size=8, + # tight_memory=True, + # ), + # dict(model_path="allenai/OLMo-2-1124-7B-Instruct"), + # dict( + # model_path="ibm-granite/granite-3.0-2b-instruct", + # prefill_tolerance=0.22, + # decode_tolerance=0.22, + # ), ] - -# TODO Ask: this is extracted from PortArgs.init_new, is it allowed to extract it, i.e. 
touch that old code -def find_available_port(base_port: int): - port = base_port + random.randint(100, 1000) - while True: - if is_port_available(port): - return port - if port < 60000: - port += 42 - else: - port -= 43 - - +# The port for the server +# 2345 is the default port for the server PORT = find_available_port(2345) @@ -78,6 +82,7 @@ def assert_fragment_e2e_execution( prefill_tolerance: float = 0.1, decode_tolerance: float = 0.1, ): + master_port = find_available_port(23456) print(f"assert_fragment_e2e_execution START {index=} {model_path=}") @@ -111,12 +116,10 @@ def assert_fragment_e2e_execution( for p in processes: p.join() - def test_ci_models(self): - for index, model_info in enumerate(CI_MODELS): - self.assert_fragment_e2e_execution(index=index, **model_info) - - def test_others(self): + def test_models(self): if is_in_ci(): + for index, model_info in enumerate(CI_MODELS): + self.assert_fragment_e2e_execution(index=index, **model_info) return for index, model_info in enumerate(ALL_OTHER_MODELS): @@ -164,13 +167,10 @@ def _run_subprocess( torch_dtype=_TORCH_DTYPE, output_str_only=False, ) - print( - f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}", - flush=True, - ) if _ENABLE_UPDATE_WEIGHTS: if tight_memory: + # If tight_memory is True, we need to move the model to CPU to save memory hf_model.cpu() torch.cuda.empty_cache() @@ -190,8 +190,6 @@ def _run_subprocess( enable_memory_saver=True, port=PORT, ) - print(f"subprocess[{tp_rank=}] {engine=}", flush=True) - # test direct generate API print(f"subprocess[{tp_rank=}] testing direct generate API") direct_response = engine.generate( @@ -204,10 +202,10 @@ def _run_subprocess( print(f"subprocess[{tp_rank=}] testing memory occupation APIs") engine.release_memory_occupation() print("Memory released") - time.sleep(1) + # time.sleep(1) engine.resume_memory_occupation() print("Memory resumed") - time.sleep(2) + # openai API test for reference torch.distributed.barrier() if tp_rank == 0: @@ -219,18 +217,11 @@ def _run_subprocess( print(response.json()) if _ENABLE_UPDATE_WEIGHTS: print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True) - # check_tensor = [(k, v) for k, v in fsdp_state_dict.items()][0] - # update_tensor = [check_tensor[0], torch.zeros_like(check_tensor[1])] + engine.update_weights_from_tensor( [(k, v) for k, v in fsdp_state_dict.items()] ) - # if tp_rank == 0: - # response = requests.get( - # f"http://localhost:{PORT}/get_weights_by_name", - # json={"name": list(fsdp_state_dict.keys())[0], "truncate_size": 5}, - # timeout=20, - # ) - # print(response.json()) + print(f"subprocess[{tp_rank=}] testing direct generate API") direct_response = engine.generate( prompt="Hello, world!", @@ -251,7 +242,7 @@ def _run_subprocess( engine.shutdown() print(f"subprocess[{tp_rank=}] end", flush=True) - + # Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py def _get_fsdp_state_dict(hf_model, tp_size: int): device_mesh = init_device_mesh( @@ -284,6 +275,5 @@ def _get_fsdp_state_dict(hf_model, tp_size: int): return fsdp_model.state_dict() - if __name__ == "__main__": unittest.main() From 9375675e253ef2437038750b9579fc2fa4533025 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=93=E5=A4=A9=E8=8E=B2?= <91449279+yitianlian@users.noreply.github.com> Date: Mon, 7 Apr 2025 15:40:29 +0000 Subject: [PATCH 11/25] revise some code structure --- python/sglang/srt/entrypoints/http_server.py | 2 +- python/sglang/test/test_utils.py | 8 +++++++- test/srt/test_verl_engine.py | 3 +-- 
test/srt/test_verl_engine_server.py | 18 ++++++++++++++---- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 2e60a80e3ae..749de5de49e 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -421,7 +421,7 @@ async def update_weights_from_tensor( ): """Update the weights from tensor inplace without re-launching the server. Notes: - 1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors. + 1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors. 2. HTTPS will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model. """ if isinstance(obj, str): diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index e524f92e8d5..118c894cb7d 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -25,7 +25,12 @@ from sglang.global_config import global_config from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry, is_port_available +from sglang.srt.utils import ( + get_bool_env_var, + is_port_available, + kill_process_tree, + retry, +) from sglang.test.run_eval import run_eval from sglang.utils import get_exception_traceback @@ -102,6 +107,7 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None) pred = res.json()["generated_text"][0] return pred + def find_available_port(base_port: int): port = base_port + random.randint(100, 1000) while True: diff --git a/test/srt/test_verl_engine.py b/test/srt/test_verl_engine.py index f600aa3a6b3..d139cd9c348 100644 --- a/test/srt/test_verl_engine.py +++ b/test/srt/test_verl_engine.py @@ -27,7 +27,7 @@ check_close_model_outputs, get_dtype_str, ) -from sglang.test.test_utils import CustomTestCase, is_in_ci, find_available_port +from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci _MAX_NEW_TOKENS = 8 _PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="] @@ -282,6 +282,5 @@ def _get_fsdp_state_dict(hf_model, tp_size: int): return fsdp_model.state_dict() - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py index 1baa475cc95..3cb2d737562 100644 --- a/test/srt/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -14,14 +14,23 @@ from torch.distributed.fsdp import CPUOffload from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision -from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType +from torch.distributed.fsdp.api import ( + ShardedStateDictConfig, + ShardingStrategy, + StateDictType, +) from transformers import AutoModelForCausalLM from sglang.srt.entrypoints.verl_engine import VerlEngine from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available -from sglang.test.runners import HFRunner, SRTRunner, check_close_model_outputs, get_dtype_str +from 
sglang.test.runners import ( + HFRunner, + SRTRunner, + check_close_model_outputs, + get_dtype_str, +) from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci _MAX_NEW_TOKENS = 8 @@ -205,7 +214,7 @@ def _run_subprocess( # time.sleep(1) engine.resume_memory_occupation() print("Memory resumed") - + # openai API test for reference torch.distributed.barrier() if tp_rank == 0: @@ -242,7 +251,7 @@ def _run_subprocess( engine.shutdown() print(f"subprocess[{tp_rank=}] end", flush=True) - + # Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py def _get_fsdp_state_dict(hf_model, tp_size: int): device_mesh = init_device_mesh( @@ -275,5 +284,6 @@ def _get_fsdp_state_dict(hf_model, tp_size: int): return fsdp_model.state_dict() + if __name__ == "__main__": unittest.main() From d13e44dd581bae9add377cf1da428611510de063 Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Mon, 7 Apr 2025 19:30:26 +0000 Subject: [PATCH 12/25] Update comments and CI testing logic --- test/srt/test_verl_engine_server.py | 147 +++++++++++++++++++++++++--- 1 file changed, 135 insertions(+), 12 deletions(-) diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py index 3cb2d737562..58755e99bc7 100644 --- a/test/srt/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -71,10 +71,18 @@ # ), ] -# The port for the server -# 2345 is the default port for the server +# This port is used for HTTP API communication with the VerlEngine server +# It handles client requests for text generation, weight updates, and memory management +# This port must be available and not used by other processes PORT = find_available_port(2345) +# Master port is used for PyTorch's distributed communication setup +# It enables tensor-parallel processes to communicate with each other +# Default is 23456, but we find an available port dynamically in assert_fragment_e2e_execution +# This port is critical for torch.distributed.init_process_group to function properly +# Each test needs a unique master_port to avoid conflicts between parallel test executions +# master_port = find_available_port(23456) # This is set in assert_fragment_e2e_execution method + class TestVerlEngine(CustomTestCase): @classmethod @@ -91,6 +99,27 @@ def assert_fragment_e2e_execution( prefill_tolerance: float = 0.1, decode_tolerance: float = 0.1, ): + """ + Tests VerlEngine with tensor parallelism across multiple processes. + + Spawns tp_size processes to test distributed execution, including: + - Model inference via direct API and HTTP server + - Weight updating functionality + - Memory management (release/resume) + + The test validates output correctness against a reference implementation + within specified tolerance bounds. + + Parameters: + ----------- + index: int - Test index for logging + model_path: str - HuggingFace model identifier + mem_fraction_static: float - Memory fraction for static tensors + tp_size: int - Number of tensor parallel processes + tight_memory: bool - Enable memory optimization + prefill_tolerance: float - Max error for prefill computation + decode_tolerance: float - Max error for decoding computation + """ master_port = find_available_port(23456) @@ -126,11 +155,25 @@ def assert_fragment_e2e_execution( p.join() def test_models(self): + """ + Orchestrates end-to-end testing across configured model sets. + + In CI environments: Randomly selects one model for faster testing. + In development: Tests all configured models for comprehensive validation. 
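The comments above distinguish the HTTP port from the torch.distributed master port and note that each run needs a fresh master port. A condensed sketch of that orchestration pattern follows; run_worker is a stand-in for the real _run_subprocess, and find_available_port is the helper this series adds to sglang.test.test_utils.

import multiprocessing as mp

from sglang.test.test_utils import find_available_port


def run_worker(tp_rank, tp_size, master_port, output_writer):
    # Stand-in worker: a real run would export MASTER_PORT=master_port, set up
    # torch.distributed, exercise the engine, and then report success.
    output_writer.send(True)
    output_writer.close()


def launch_fragment_test(tp_size=2):
    master_port = find_available_port(23456)  # fresh port per test run
    readers, processes = [], []
    for tp_rank in range(tp_size):
        reader, writer = mp.Pipe(duplex=False)
        p = mp.Process(target=run_worker, args=(tp_rank, tp_size, master_port, writer))
        p.start()
        readers.append(reader)
        processes.append(p)
    assert all(reader.recv() for reader in readers)
    for p in processes:
        p.join()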
+ + Each model configuration specifies model path, memory settings, + tensor-parallel size, and error tolerance bounds. + """ if is_in_ci(): - for index, model_info in enumerate(CI_MODELS): - self.assert_fragment_e2e_execution(index=index, **model_info) + # Randomly select one model in CI for faster testing + if CI_MODELS: # Make sure list is not empty + model_info = random.choice(CI_MODELS) + print(f"CI environment: Testing randomly selected model: {model_info['model_path']}") + self.assert_fragment_e2e_execution(index=0, **model_info) return + # Test all models in development environment + print(f"Development environment: Testing all {len(ALL_OTHER_MODELS)} models") for index, model_info in enumerate(ALL_OTHER_MODELS): self.assert_fragment_e2e_execution(index=index, **model_info) @@ -146,6 +189,28 @@ def _run_subprocess( prefill_tolerance: float, decode_tolerance: float, ): + """ + Executes a single tensor-parallel process for testing VerlEngine. + + Performs the core test operations: + 1. Initializes distributed environment + 2. Loads HuggingFace model for reference + 3. Tests VerlEngine API (generation, memory management, weight updates) + 4. Tests OpenAI-compatible endpoints on rank 0 + + Reports success/failure via output_writer pipe. + + Parameters: + tp_rank: int - Process rank in tensor parallel group + tp_size: int - Total processes in tensor parallel group + master_port: int - Port for distributed communication + output_writer - Pipe for result communication + model_path: str - HuggingFace model identifier + mem_fraction_static: float - Static memory allocation fraction + tight_memory: bool - Memory optimization flag + prefill_tolerance: float - Acceptable prefill error + decode_tolerance: float - Acceptable decode error + """ try: print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}") @@ -157,8 +222,15 @@ def _run_subprocess( mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"]) inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs) inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs) + # Print basic information about this subprocess including: + # - Current tensor-parallel rank + # - Device mesh configuration for both CUDA and CPU + # - This subprocess's role in testing tensor-parallel execution + # - How it contributes to the distributed model testing print( - f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}" + f"subprocess[{tp_rank=}] initialized for VerlEngine testing - " + f"Role: Shard {tp_rank+1}/{tp_size} of tensor-parallel model execution | " + f"Device meshes: CUDA={inference_device_mesh_device}, CPU={inference_device_mesh_cpu}" ) # hf model is used for comparison @@ -199,13 +271,33 @@ def _run_subprocess( enable_memory_saver=True, port=PORT, ) - # test direct generate API - print(f"subprocess[{tp_rank=}] testing direct generate API") + # test direct generate API with multiple different requests + print(f"subprocess[{tp_rank=}] testing direct generate API with multiple requests") + + # Request 1: Basic generation with temperature + print(f"subprocess[{tp_rank=}] test request 1: Basic generation") direct_response = engine.generate( prompt="Hello, world!", sampling_params={"temperature": 0.7, "max_new_tokens": 20}, ) - print(f"Direct generate response: {direct_response}") + print(f"Response 1: {direct_response}") + + # Request 2: Zero temperature (greedy) generation + print(f"subprocess[{tp_rank=}] test request 2: Greedy generation") + direct_response = 
engine.generate( + prompt="Complete this sequence: 1, 2, 3,", + sampling_params={"temperature": 0.0, "max_new_tokens": 10}, + ) + print(f"Response 2: {direct_response}") + + # Request 3: Batch generation + print(f"subprocess[{tp_rank=}] test request 3: Batch generation") + batch_response = engine.generate( + prompt=["Translate 'hello' to French:", "Translate 'goodbye' to Spanish:"], + sampling_params={"temperature": 0.8, "max_new_tokens": 15}, + ) + print(f"Response 3: {batch_response}") + # test memory occupation APIs print(f"subprocess[{tp_rank=}] testing memory occupation APIs") @@ -220,10 +312,26 @@ def _run_subprocess( if tp_rank == 0: client = OpenAI(api_key="None", base_url=f"http://localhost:{PORT}/v1") print(client.models.list().data[0].id) + + # Multiple HTTP API requests + print("Testing HTTP API with multiple requests") + + # Request 1 url = f"http://localhost:{PORT}/generate" data = {"text": "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="} response = requests.post(url, json=data) - print(response.json()) + print(f"HTTP Response 1: {response.json()}") + + # Request 2 + data = {"text": "The capital of France is", "sampling_params": {"temperature": 0.2}} + response = requests.post(url, json=data) + print(f"HTTP Response 2: {response.json()}") + + # Request 3 + data = {"text": "List three colors:", "sampling_params": {"top_p": 0.95, "max_new_tokens": 25}} + response = requests.post(url, json=data) + print(f"HTTP Response 3: {response.json()}") + if _ENABLE_UPDATE_WEIGHTS: print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True) @@ -231,12 +339,13 @@ def _run_subprocess( [(k, v) for k, v in fsdp_state_dict.items()] ) - print(f"subprocess[{tp_rank=}] testing direct generate API") + # Final generation test after weight update + print(f"subprocess[{tp_rank=}] testing generation after weight update") direct_response = engine.generate( - prompt="Hello, world!", + prompt="After weight update: Hello, world!", sampling_params={"temperature": 0.7, "max_new_tokens": 20}, ) - print(f"Direct generate response: {direct_response}") + print(f"Post-update response: {direct_response}") execution_ok = True @@ -254,6 +363,20 @@ def _run_subprocess( # Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py def _get_fsdp_state_dict(hf_model, tp_size: int): + """ + Creates a sharded state dictionary for weight update testing. + + Wraps the HuggingFace model with FSDP (FullyShardedDataParallel), + configures precision settings, and returns a sharded state dict + for testing VerlEngine's weight update capabilities. + + Parameters: + hf_model - HuggingFace model to wrap + tp_size: int - Number of tensor-parallel shards + + Returns: + dict - Sharded state dict for update_weights_from_tensor + """ device_mesh = init_device_mesh( "cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"] ) From e10dea9a917edee554c494f2d24078f3a8218a51 Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Mon, 7 Apr 2025 20:01:39 +0000 Subject: [PATCH 13/25] Fix lint check --- test/srt/test_verl_engine_server.py | 57 +++++++++++++++++------------ 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py index 58755e99bc7..97db0c6e1cd 100644 --- a/test/srt/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -101,15 +101,15 @@ def assert_fragment_e2e_execution( ): """ Tests VerlEngine with tensor parallelism across multiple processes. 
- + Spawns tp_size processes to test distributed execution, including: - Model inference via direct API and HTTP server - Weight updating functionality - Memory management (release/resume) - + The test validates output correctness against a reference implementation within specified tolerance bounds. - + Parameters: ----------- index: int - Test index for logging @@ -157,10 +157,10 @@ def assert_fragment_e2e_execution( def test_models(self): """ Orchestrates end-to-end testing across configured model sets. - + In CI environments: Randomly selects one model for faster testing. In development: Tests all configured models for comprehensive validation. - + Each model configuration specifies model path, memory settings, tensor-parallel size, and error tolerance bounds. """ @@ -168,7 +168,9 @@ def test_models(self): # Randomly select one model in CI for faster testing if CI_MODELS: # Make sure list is not empty model_info = random.choice(CI_MODELS) - print(f"CI environment: Testing randomly selected model: {model_info['model_path']}") + print( + f"CI environment: Testing randomly selected model: {model_info['model_path']}" + ) self.assert_fragment_e2e_execution(index=0, **model_info) return @@ -191,15 +193,15 @@ def _run_subprocess( ): """ Executes a single tensor-parallel process for testing VerlEngine. - + Performs the core test operations: 1. Initializes distributed environment 2. Loads HuggingFace model for reference 3. Tests VerlEngine API (generation, memory management, weight updates) 4. Tests OpenAI-compatible endpoints on rank 0 - + Reports success/failure via output_writer pipe. - + Parameters: tp_rank: int - Process rank in tensor parallel group tp_size: int - Total processes in tensor parallel group @@ -272,8 +274,10 @@ def _run_subprocess( port=PORT, ) # test direct generate API with multiple different requests - print(f"subprocess[{tp_rank=}] testing direct generate API with multiple requests") - + print( + f"subprocess[{tp_rank=}] testing direct generate API with multiple requests" + ) + # Request 1: Basic generation with temperature print(f"subprocess[{tp_rank=}] test request 1: Basic generation") direct_response = engine.generate( @@ -281,7 +285,7 @@ def _run_subprocess( sampling_params={"temperature": 0.7, "max_new_tokens": 20}, ) print(f"Response 1: {direct_response}") - + # Request 2: Zero temperature (greedy) generation print(f"subprocess[{tp_rank=}] test request 2: Greedy generation") direct_response = engine.generate( @@ -289,7 +293,7 @@ def _run_subprocess( sampling_params={"temperature": 0.0, "max_new_tokens": 10}, ) print(f"Response 2: {direct_response}") - + # Request 3: Batch generation print(f"subprocess[{tp_rank=}] test request 3: Batch generation") batch_response = engine.generate( @@ -298,7 +302,6 @@ def _run_subprocess( ) print(f"Response 3: {batch_response}") - # test memory occupation APIs print(f"subprocess[{tp_rank=}] testing memory occupation APIs") engine.release_memory_occupation() @@ -312,26 +315,32 @@ def _run_subprocess( if tp_rank == 0: client = OpenAI(api_key="None", base_url=f"http://localhost:{PORT}/v1") print(client.models.list().data[0].id) - + # Multiple HTTP API requests print("Testing HTTP API with multiple requests") - + # Request 1 url = f"http://localhost:{PORT}/generate" data = {"text": "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="} response = requests.post(url, json=data) print(f"HTTP Response 1: {response.json()}") - + # Request 2 - data = {"text": "The capital of France is", "sampling_params": {"temperature": 0.2}} + data = { + "text": "The capital 
of France is", + "sampling_params": {"temperature": 0.2}, + } response = requests.post(url, json=data) print(f"HTTP Response 2: {response.json()}") - + # Request 3 - data = {"text": "List three colors:", "sampling_params": {"top_p": 0.95, "max_new_tokens": 25}} + data = { + "text": "List three colors:", + "sampling_params": {"top_p": 0.95, "max_new_tokens": 25}, + } response = requests.post(url, json=data) print(f"HTTP Response 3: {response.json()}") - + if _ENABLE_UPDATE_WEIGHTS: print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True) @@ -365,15 +374,15 @@ def _run_subprocess( def _get_fsdp_state_dict(hf_model, tp_size: int): """ Creates a sharded state dictionary for weight update testing. - + Wraps the HuggingFace model with FSDP (FullyShardedDataParallel), configures precision settings, and returns a sharded state dict for testing VerlEngine's weight update capabilities. - + Parameters: hf_model - HuggingFace model to wrap tp_size: int - Number of tensor-parallel shards - + Returns: dict - Sharded state dict for update_weights_from_tensor """ From 6877dbcc01586a4a15dc2e9db0cd2c356c8391d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=93=E5=A4=A9=E8=8E=B2?= <91449279+yitianlian@users.noreply.github.com> Date: Tue, 8 Apr 2025 03:28:02 +0000 Subject: [PATCH 14/25] revise CI testing logic --- test/srt/test_verl_engine_server.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py index 97db0c6e1cd..103a97b2176 100644 --- a/test/srt/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -46,7 +46,7 @@ # dict(model_path="google/gemma-2-2b"), ] ALL_OTHER_MODELS = [ - # dict(model_path="meta-llama/Llama-3.2-1B-Instruct"), + dict(model_path="meta-llama/Llama-3.2-1B-Instruct", tp_size=1), dict(model_path="Qwen/Qwen2-1.5B"), dict( model_path="Qwen/Qwen2.5-14B-Instruct", @@ -164,19 +164,13 @@ def test_models(self): Each model configuration specifies model path, memory settings, tensor-parallel size, and error tolerance bounds. 
""" + test_models = ALL_OTHER_MODELS if is_in_ci(): # Randomly select one model in CI for faster testing - if CI_MODELS: # Make sure list is not empty - model_info = random.choice(CI_MODELS) - print( - f"CI environment: Testing randomly selected model: {model_info['model_path']}" - ) - self.assert_fragment_e2e_execution(index=0, **model_info) - return - + test_models = [random.choice(ALL_OTHER_MODELS)] # Test all models in development environment print(f"Development environment: Testing all {len(ALL_OTHER_MODELS)} models") - for index, model_info in enumerate(ALL_OTHER_MODELS): + for index, model_info in enumerate(test_models): self.assert_fragment_e2e_execution(index=index, **model_info) From 58cf1805f5b3236513857a7e8c3ede4b7262dea4 Mon Sep 17 00:00:00 2001 From: Chayenne Date: Mon, 7 Apr 2025 21:08:03 -0700 Subject: [PATCH 15/25] Update base_engine.py --- python/sglang/srt/entrypoints/base_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/base_engine.py b/python/sglang/srt/entrypoints/base_engine.py index 3bf3a86c633..894a6f06519 100644 --- a/python/sglang/srt/entrypoints/base_engine.py +++ b/python/sglang/srt/entrypoints/base_engine.py @@ -24,7 +24,7 @@ def generate( lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, ) -> Union[Dict, Iterator[Dict]]: - """Generate text based on given inputs.""" + """Generate outputs based on given inputs.""" pass @abstractmethod @@ -44,7 +44,7 @@ def release_memory_occupation(self): @abstractmethod def resume_memory_occupation(self): - """Resume GPU memory occupation previously released.""" + """Resume GPU memory occupation which is previously released.""" pass @abstractmethod From 66d236b4af19c51ec21c52856b5d205557de8b8b Mon Sep 17 00:00:00 2001 From: Chayenne Date: Mon, 7 Apr 2025 21:12:00 -0700 Subject: [PATCH 16/25] Update http_server_engine.py --- python/sglang/srt/entrypoints/http_server_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index bb0118f01cd..66082c7f370 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -59,7 +59,6 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: class HttpServerEngineForRL(EngineBase): def __init__(self, **kwargs): self.server_args = ServerArgs(**kwargs) - # self.server_args = copy.deepcopy(server_args) print(f"launch_server_from_verl_engine {self.server_args.port}") self.process = launch_server_process(self.server_args) From e97cd3ff59c09089a322e4183f023287447996ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=93=E5=A4=A9=E8=8E=B2?= <91449279+yitianlian@users.noreply.github.com> Date: Tue, 8 Apr 2025 04:34:32 +0000 Subject: [PATCH 17/25] revise some expression in http_server_engine --- python/sglang/srt/entrypoints/http_server_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 66082c7f370..0431127d756 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -73,7 +73,7 @@ def update_weights_from_tensor( flush_cache: bool = False, ): """ - Update model weights from tensor data. + Update model weights from tensor data. 
The HTTPS server will only post meta data, and the real weights will be copied directly from GPUs. Note: The model should be on GPUs rather than CPU for this functionality to work properly. If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. From db3937e71b1c681ba767187498293fe8fddc9221 Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Wed, 9 Apr 2025 01:22:51 +0000 Subject: [PATCH 18/25] Refactoring Code Structure --- .../{base_engine.py => EngineBase.py} | 0 python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/entrypoints/http_server.py | 20 ++----- .../srt/entrypoints/http_server_engine.py | 55 ++++++++--------- python/sglang/srt/entrypoints/verl_engine.py | 15 ++--- python/sglang/srt/managers/io_struct.py | 19 ++++-- python/sglang/srt/utils.py | 59 +++++++++++++++++-- 7 files changed, 108 insertions(+), 62 deletions(-) rename python/sglang/srt/entrypoints/{base_engine.py => EngineBase.py} (100%) diff --git a/python/sglang/srt/entrypoints/base_engine.py b/python/sglang/srt/entrypoints/EngineBase.py similarity index 100% rename from python/sglang/srt/entrypoints/base_engine.py rename to python/sglang/srt/entrypoints/EngineBase.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index bf9ca49e5c4..b0635b99110 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -317,7 +317,7 @@ def update_weights_from_tensor( to avoid duplicated operations such as clearing cache.""" obj = UpdateWeightsFromTensorReqInput( serialized_named_tensors=[ - MultiprocessingSerializer.serialize(named_tensors) + HttpSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size) ], load_format=load_format, diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 749de5de49e..1004150763b 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -31,8 +31,6 @@ # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) -import base64 -import pickle from contextlib import asynccontextmanager import numpy as np @@ -417,25 +415,17 @@ async def init_weights_update_group( @app.post("/update_weights_from_tensor") async def update_weights_from_tensor( - obj: Union[UpdateWeightsFromTensorReqInput, str], request: Request + obj: UpdateWeightsFromTensorReqInput, request: Request ): """Update the weights from tensor inplace without re-launching the server. Notes: 1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors. 2. HTTPS will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model. + 3. Any binary data in the named tensors should be base64 encoded. 
""" - if isinstance(obj, str): - try: - obj = HttpSerializer.deserialize(obj) - except Exception as e: - return ORJSONResponse( - {"success": False, "message": f"Failed to decode input: {str(e)}"}, - status_code=HTTPStatus.BAD_REQUEST, - ) - else: - obj.serialized_named_tensors = [ - HttpSerializer.deserialize(item) for item in obj.serialized_named_tensors - ] + obj.serialized_named_tensors = [ + HttpSerializer.deserialize(item) for item in obj.serialized_named_tensors + ] success, message = await _global_state.tokenizer_manager.update_weights_from_tensor( obj, request diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 0431127d756..a35f20a842e 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -14,24 +14,16 @@ from sglang.srt.entrypoints.base_engine import EngineBase from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - HttpSerializer, - MultiprocessingSerializer, - kill_process_tree, -) - - -def launch_server_worker(server_args: ServerArgs): - launch_server(server_args) +from sglang.srt.utils import HttpSerializer, kill_process_tree def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: - p = multiprocessing.Process(target=launch_server_worker, args=(server_args,)) + p = multiprocessing.Process(target=launch_server, args=(server_args,)) p.start() base_url = server_args.url() - timeout = 180.0 + timeout = 300.0 # Increased timeout to 5 minutes for downloading large models start_time = time.time() with requests.Session() as session: @@ -62,9 +54,20 @@ def __init__(self, **kwargs): print(f"launch_server_from_verl_engine {self.server_args.port}") self.process = launch_server_process(self.server_args) - def _url(self, path: str) -> str: - """Construct full URL for server endpoint.""" - return f"http://{self.server_args.host}:{self.server_args.port}/{path}" + def _make_request(self, endpoint: str, payload: dict = None): + """Make a POST request to the specified endpoint with the given payload. 
+ + Args: + endpoint: The API endpoint to call + payload: The JSON payload to send (default: empty dict) + + Returns: + The JSON response from the server + """ + url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}" + response = requests.post(url, json=payload or {}) + response.raise_for_status() + return response.json() def update_weights_from_tensor( self, @@ -80,13 +83,11 @@ def update_weights_from_tensor( """ print(f"update_weights_from_tensor of HttpServerEngineForRL") - serialized_named_tensors = HttpSerializer.serialize( - MultiprocessingSerializer.serialize(named_tensors) - ) + serialized_named_tensors = HttpSerializer.serialize(named_tensors) - response = requests.post( - self._url("update_weights_from_tensor"), - json={ + return self._make_request( + "update_weights_from_tensor", + { "serialized_named_tensors": [ serialized_named_tensors for _ in range(self.server_args.tp_size) ], @@ -94,8 +95,6 @@ def update_weights_from_tensor( "flush_cache": flush_cache, }, ) - response.raise_for_status() - return response.json() def shutdown(self): kill_process_tree(self.process.pid) @@ -128,16 +127,10 @@ def generate( # Filter out None values payload = {k: v for k, v in payload.items() if v is not None} - response = requests.post(self._url("generate"), json=payload) - response.raise_for_status() - return response.json() + return self._make_request("generate", payload) def release_memory_occupation(self): - response = requests.post(self._url("release_memory_occupation"), json={}) - response.raise_for_status() - return response.json() + return self._make_request("release_memory_occupation") def resume_memory_occupation(self): - response = requests.post(self._url("resume_memory_occupation"), json={}) - response.raise_for_status() - return response.json() + return self._make_request("resume_memory_occupation") diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index 966f843cc98..5f3101c36ba 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -18,7 +18,7 @@ import torch.distributed as dist from torch.distributed.tensor import DeviceMesh, DTensor -from sglang.srt.entrypoints.http_server_engine import HttpServerEngineForRL +from sglang.srt.entrypoints.http_server_engine import HttpServerEngine from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server import Engine @@ -42,20 +42,21 @@ def __init__( node_rank = self._tp_rank // tp_size_per_node first_rank_in_node = self._tp_rank % tp_size_per_node == 0 + # Common engine keyword arguments + engine_kwargs = dict( + **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes + ) + if backend == "engine": if first_rank_in_node: os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = Engine( - **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes - ) + self._engine = Engine(**engine_kwargs) else: self._engine = None elif backend == "server": if self._tp_rank == 0: - self._engine = HttpServerEngineForRL( - **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes - ) + self._engine = HttpServerEngine(**engine_kwargs) else: self._engine = None else: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e25a8f242c5..2c3d054451f 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ 
-550,10 +550,21 @@ class UpdateWeightsFromDistributedReqOutput: @dataclass class UpdateWeightsFromTensorReqInput: - # List containing one serialized Dict[str, torch.Tensor] per TP worker - serialized_named_tensors: List[bytes] - load_format: Optional[str] - flush_cache: bool + """Update model weights from tensor input. + + The serialization of tensor data uses HttpSerializer: + - Binary data like tensors are base64 encoded + - Data is structured in JSON for easy transmission over HTTP + - No pickle serialization is used for security reasons + """ + + # List containing serialized Dict[str, torch.Tensor] data for each TP worker + # Each item is serialized using HttpSerializer.serialize() + serialized_named_tensors: List[str] + # Optional format specification for loading + load_format: Optional[str] = None + # Whether to flush the cache after updating weights + flush_cache: bool = True @dataclass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index a6aa1e732ff..d3c97ee6aaa 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1494,13 +1494,64 @@ def deserialize(data): class HttpSerializer: @staticmethod def serialize(data): - pickled = pickle.dumps(data) - return base64.b64encode(pickled).decode("utf-8") + """Serialize data for HTTP transmission. + If data is bytes, it encodes it directly with base64. + For complex objects, it serializes them to JSON with base64 encoding for any binary data. + """ + if isinstance(data, bytes): + # Directly encode bytes with base64 + return base64.b64encode(data).decode("utf-8") + + # For other types, convert to JSON-serializable format with base64 for any binary data + def encode_binary(obj): + if isinstance(obj, bytes): + # For bytes, encode as base64 and mark as binary + return {"__binary__": base64.b64encode(obj).decode("utf-8")} + elif isinstance(obj, dict): + # Process dictionaries recursively + return {k: encode_binary(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + # Process lists and tuples recursively + return [encode_binary(item) for item in obj] + # Other types are returned as is (assuming they're JSON serializable) + return obj + + # Encode the data and serialize to JSON + encoded_data = encode_binary(data) + return json.dumps(encoded_data) @staticmethod def deserialize(data): - pickled = base64.b64decode(data) - return pickle.loads(pickled) + """Deserialize data from HTTP transmission. + First tries to parse as JSON, if that fails, assumes it's base64-encoded bytes. 
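A short round trip through the scheme described above may help: JSON-friendly values pass through unchanged, while bytes values travel as {"__binary__": <base64>} markers. The payload below is made up; HttpSerializer is the class this patch keeps in sglang.srt.utils at this point in the series (a later patch replaces it).

from sglang.srt.utils import HttpSerializer

payload = {"name": "layer0.weight", "shape": [2, 3], "blob": b"\x00\x01\x02"}
wire = HttpSerializer.serialize(payload)       # JSON string, blob wrapped as __binary__
restored = HttpSerializer.deserialize(wire)
assert restored["blob"] == payload["blob"]
assert restored["shape"] == payload["shape"]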
+ """ + try: + # Try to parse as JSON + json_data = json.loads(data) + + # Function to decode binary data recursively + def decode_binary(obj): + if isinstance(obj, dict): + if "__binary__" in obj and len(obj) == 1: + # This is a binary object encoded with base64 + return base64.b64decode(obj["__binary__"]) + # Process dictionaries recursively + return {k: decode_binary(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Process lists recursively + return [decode_binary(item) for item in obj] + # Other types are returned as is + return obj + + return decode_binary(json_data) + except json.JSONDecodeError: + # If not JSON, assume it's base64-encoded bytes + try: + return base64.b64decode(data) + except: + raise ValueError( + f"Failed to deserialize data: not valid JSON or base64" + ) def debug_timing(func): From c266d4aa7a744cae66ab1347900ca1e89f662383 Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Wed, 9 Apr 2025 04:20:20 +0000 Subject: [PATCH 19/25] For Sync --- python/sglang/srt/entrypoints/verl_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index be7f27e1627..ec8c1c1823b 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -19,7 +19,7 @@ from PIL.Image import Image from torch.distributed.tensor import DeviceMesh, DTensor -from sglang.srt.entrypoints.http_server_engine import HttpServerEngine +from sglang.srt.entrypoints.http_server_engine import HttpServerEngineForRL from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server import Engine @@ -57,7 +57,7 @@ def __init__( elif backend == "server": if self._tp_rank == 0: - self._engine = HttpServerEngine(**engine_kwargs) + self._engine = HttpServerEngineForRL(**engine_kwargs) else: self._engine = None else: From dca2e96aaa069d71fa07d9ce2e5d47a12a95b76a Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Wed, 9 Apr 2025 04:38:57 +0000 Subject: [PATCH 20/25] Revert MP in Engine --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/entrypoints/http_server_engine.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 0c3fe84b217..9ef2ba53edd 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -344,7 +344,7 @@ def update_weights_from_tensor( to avoid duplicated operations such as clearing cache.""" obj = UpdateWeightsFromTensorReqInput( serialized_named_tensors=[ - HttpSerializer.serialize(named_tensors) + MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size) ], load_format=load_format, diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index a35f20a842e..a16ddc6fca6 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -14,7 +14,11 @@ from sglang.srt.entrypoints.base_engine import EngineBase from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import HttpSerializer, kill_process_tree +from sglang.srt.utils import ( + HttpSerializer, + MultiprocessingSerializer, + kill_process_tree, +) def launch_server_process(server_args: ServerArgs) -> 
multiprocessing.Process: From d38ea8dd265ad98e89df1cb90263d525564f7c9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=93=E5=A4=A9=E8=8E=B2?= <91449279+yitianlian@users.noreply.github.com> Date: Wed, 9 Apr 2025 14:40:22 +0000 Subject: [PATCH 21/25] update method of updating weights --- python/sglang/srt/entrypoints/engine.py | 2 +- python/sglang/srt/entrypoints/http_server.py | 8 +- .../srt/entrypoints/http_server_engine.py | 12 +-- python/sglang/srt/managers/io_struct.py | 3 - python/sglang/srt/utils.py | 90 ++++++------------- test/srt/test_verl_engine_server.py | 14 +-- 6 files changed, 44 insertions(+), 85 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 9ef2ba53edd..9738e466ce6 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -38,7 +38,7 @@ import uvloop from sglang.srt.code_completion_parser import load_completion_template_for_openai_api -from sglang.srt.entrypoints.base_engine import EngineBase +from sglang.srt.entrypoints.EngineBase import EngineBase from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, ) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 1004150763b..b0537c8e617 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -25,9 +25,12 @@ import os import threading import time +from ast import Mult from http import HTTPStatus from typing import AsyncIterator, Callable, Dict, Optional, Union +from sglang.srt.model_executor.model_runner import LocalSerializedTensor + # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -81,7 +84,7 @@ from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( - HttpSerializer, + MultiprocessingSerializer, add_api_key_middleware, add_prometheus_middleware, delete_directory, @@ -423,9 +426,6 @@ async def update_weights_from_tensor( 2. HTTPS will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model. 3. Any binary data in the named tensors should be base64 encoded. 
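The diff below replaces the ad-hoc HTTP serializer with MultiprocessingSerializer, whose new output_str flag base64-encodes the pickled bytes so they can ride inside a JSON body, and whose deserialize() accepts either form. A round-trip sketch under that assumption, using a plain Python object in place of the real named-tensor list:

from sglang.srt.utils import MultiprocessingSerializer

obj = {"step": 7, "names": ["lm_head.weight", "model.embed_tokens.weight"]}
as_bytes = MultiprocessingSerializer.serialize(obj)
as_text = MultiprocessingSerializer.serialize(obj, output_str=True)  # base64 string
assert MultiprocessingSerializer.deserialize(as_bytes) == obj
assert MultiprocessingSerializer.deserialize(as_text) == obj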
""" - obj.serialized_named_tensors = [ - HttpSerializer.deserialize(item) for item in obj.serialized_named_tensors - ] success, message = await _global_state.tokenizer_manager.update_weights_from_tensor( obj, request diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index a16ddc6fca6..e5bcbda79b5 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -11,14 +11,10 @@ import torch import torch.distributed as dist -from sglang.srt.entrypoints.base_engine import EngineBase +from sglang.srt.entrypoints.EngineBase import EngineBase from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - HttpSerializer, - MultiprocessingSerializer, - kill_process_tree, -) +from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: @@ -87,13 +83,13 @@ def update_weights_from_tensor( """ print(f"update_weights_from_tensor of HttpServerEngineForRL") - serialized_named_tensors = HttpSerializer.serialize(named_tensors) return self._make_request( "update_weights_from_tensor", { "serialized_named_tensors": [ - serialized_named_tensors for _ in range(self.server_args.tp_size) + MultiprocessingSerializer.serialize(named_tensors, output_str=True) + for _ in range(self.server_args.tp_size) ], "load_format": load_format, "flush_cache": flush_cache, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c018adc24f3..2144584d4b9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -694,14 +694,11 @@ class UpdateWeightsFromDistributedReqOutput: class UpdateWeightsFromTensorReqInput: """Update model weights from tensor input. - The serialization of tensor data uses HttpSerializer: - Binary data like tensors are base64 encoded - Data is structured in JSON for easy transmission over HTTP - No pickle serialization is used for security reasons """ - # List containing serialized Dict[str, torch.Tensor] data for each TP worker - # Each item is serialized using HttpSerializer.serialize() serialized_named_tensors: List[str] # Optional format specification for loading load_format: Optional[str] = None diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d3c97ee6aaa..ceac5040420 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1480,78 +1480,44 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor: class MultiprocessingSerializer: @staticmethod - def serialize(obj): + def serialize(obj, output_str: bool = False): + """ + Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ buf = io.BytesIO() ForkingPickler(buf).dump(obj) buf.seek(0) - return buf.read() + output = buf.read() - @staticmethod - def deserialize(data): - return ForkingPickler.loads(data) + if output_str: + # Convert bytes to base64-encoded string + output = base64.b64encode(output).decode("utf-8") + return output -class HttpSerializer: @staticmethod - def serialize(data): - """Serialize data for HTTP transmission. - If data is bytes, it encodes it directly with base64. 
- For complex objects, it serializes them to JSON with base64 encoding for any binary data. + def deserialize(data): """ - if isinstance(data, bytes): - # Directly encode bytes with base64 - return base64.b64encode(data).decode("utf-8") - - # For other types, convert to JSON-serializable format with base64 for any binary data - def encode_binary(obj): - if isinstance(obj, bytes): - # For bytes, encode as base64 and mark as binary - return {"__binary__": base64.b64encode(obj).decode("utf-8")} - elif isinstance(obj, dict): - # Process dictionaries recursively - return {k: encode_binary(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple)): - # Process lists and tuples recursively - return [encode_binary(item) for item in obj] - # Other types are returned as is (assuming they're JSON serializable) - return obj - - # Encode the data and serialize to JSON - encoded_data = encode_binary(data) - return json.dumps(encoded_data) + Deserialize a previously serialized object. - @staticmethod - def deserialize(data): - """Deserialize data from HTTP transmission. - First tries to parse as JSON, if that fails, assumes it's base64-encoded bytes. + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. """ - try: - # Try to parse as JSON - json_data = json.loads(data) - - # Function to decode binary data recursively - def decode_binary(obj): - if isinstance(obj, dict): - if "__binary__" in obj and len(obj) == 1: - # This is a binary object encoded with base64 - return base64.b64decode(obj["__binary__"]) - # Process dictionaries recursively - return {k: decode_binary(v) for k, v in obj.items()} - elif isinstance(obj, list): - # Process lists recursively - return [decode_binary(item) for item in obj] - # Other types are returned as is - return obj - - return decode_binary(json_data) - except json.JSONDecodeError: - # If not JSON, assume it's base64-encoded bytes - try: - return base64.b64decode(data) - except: - raise ValueError( - f"Failed to deserialize data: not valid JSON or base64" - ) + if isinstance(data, str): + # Decode base64 string to bytes + data = base64.b64decode(data) + + return ForkingPickler.loads(data) def debug_timing(func): diff --git a/test/srt/test_verl_engine_server.py b/test/srt/test_verl_engine_server.py index 103a97b2176..6b7cbd0bf6f 100644 --- a/test/srt/test_verl_engine_server.py +++ b/test/srt/test_verl_engine_server.py @@ -48,13 +48,13 @@ ALL_OTHER_MODELS = [ dict(model_path="meta-llama/Llama-3.2-1B-Instruct", tp_size=1), dict(model_path="Qwen/Qwen2-1.5B"), - dict( - model_path="Qwen/Qwen2.5-14B-Instruct", - mem_fraction_static=0.4, - tp_size=8, - tight_memory=True, - decode_tolerance=1.3, - ), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error + # dict( + # model_path="Qwen/Qwen2.5-14B-Instruct", + # mem_fraction_static=0.4, + # tp_size=8, + # tight_memory=True, + # decode_tolerance=1.3, + # ), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3), # dict(model_path="allenai/OLMo-1B-0724-hf"), # dict( From 99dcc145f82b96ad9b9c20f7ed978ca6c6b86185 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=93=E5=A4=A9=E8=8E=B2?= <91449279+yitianlian@users.noreply.github.com> Date: Thu, 10 Apr 2025 01:43:50 +0000 Subject: [PATCH 22/25] update name --- python/sglang/srt/entrypoints/http_server_engine.py | 10 +++++++--- python/sglang/srt/entrypoints/verl_engine.py | 4 ++-- 2 files changed, 9 
insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index e5bcbda79b5..d7c8e8adbf9 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -48,7 +48,13 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: raise TimeoutError("Server failed to start within the timeout period.") -class HttpServerEngineForRL(EngineBase): +class HttpServerEngineAdapter(EngineBase): + """ + You can use this class to launch a server from a VerlEngine instance. + We recommend using this class only you need to use http server. + Otherwise, you can use Engine directly. + """ + def __init__(self, **kwargs): self.server_args = ServerArgs(**kwargs) print(f"launch_server_from_verl_engine {self.server_args.port}") @@ -82,8 +88,6 @@ def update_weights_from_tensor( If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. """ - print(f"update_weights_from_tensor of HttpServerEngineForRL") - return self._make_request( "update_weights_from_tensor", { diff --git a/python/sglang/srt/entrypoints/verl_engine.py b/python/sglang/srt/entrypoints/verl_engine.py index ec8c1c1823b..d49392f4c3d 100644 --- a/python/sglang/srt/entrypoints/verl_engine.py +++ b/python/sglang/srt/entrypoints/verl_engine.py @@ -19,7 +19,7 @@ from PIL.Image import Image from torch.distributed.tensor import DeviceMesh, DTensor -from sglang.srt.entrypoints.http_server_engine import HttpServerEngineForRL +from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.server import Engine @@ -57,7 +57,7 @@ def __init__( elif backend == "server": if self._tp_rank == 0: - self._engine = HttpServerEngineForRL(**engine_kwargs) + self._engine = HttpServerEngineAdapter(**engine_kwargs) else: self._engine = None else: From ae2130b0ea2743d1accf625c3e59173d16f3265d Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Fri, 11 Apr 2025 00:28:56 +0000 Subject: [PATCH 23/25] Quick fix for review --- python/sglang/srt/entrypoints/http_server_engine.py | 4 ++-- python/sglang/srt/managers/io_struct.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index d7c8e8adbf9..32c8846b8f4 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -60,7 +60,7 @@ def __init__(self, **kwargs): print(f"launch_server_from_verl_engine {self.server_args.port}") self.process = launch_server_process(self.server_args) - def _make_request(self, endpoint: str, payload: dict = None): + def _make_request(self, endpoint: str, payload: Optional[dict] = None): """Make a POST request to the specified endpoint with the given payload. Args: @@ -82,7 +82,7 @@ def update_weights_from_tensor( flush_cache: bool = False, ): """ - Update model weights from tensor data. The HTTPS server will only post meta data, and the real weights will be copied directly from GPUs. + Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs. Note: The model should be on GPUs rather than CPU for this functionality to work properly. 
If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 6ef3c39922d..e2c3f09f3e5 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -702,12 +702,11 @@ class UpdateWeightsFromDistributedReqOutput: class UpdateWeightsFromTensorReqInput: """Update model weights from tensor input. - - Binary data like tensors are base64 encoded + - Tensors are serialized for transmission - Data is structured in JSON for easy transmission over HTTP - - No pickle serialization is used for security reasons """ - serialized_named_tensors: List[str] + serialized_named_tensors: List[Union[str, bytes]] # Optional format specification for loading load_format: Optional[str] = None # Whether to flush the cache after updating weights From 128def04d60933489fdf96e080a4958ac1771d84 Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Fri, 11 Apr 2025 00:32:46 +0000 Subject: [PATCH 24/25] One other HTTP clarification --- python/sglang/srt/entrypoints/http_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index b0537c8e617..c036037c361 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -423,7 +423,7 @@ async def update_weights_from_tensor( """Update the weights from tensor inplace without re-launching the server. Notes: 1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors. - 2. HTTPS will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model. + 2. HTTP will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model. 3. Any binary data in the named tensors should be base64 encoded. """ From 59992dfebb8a26a3dfdd4c36f865c95c6673b2a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=93=E5=A4=A9=E8=8E=B2?= <91449279+yitianlian@users.noreply.github.com> Date: Fri, 11 Apr 2025 02:38:51 +0000 Subject: [PATCH 25/25] update doc --- python/sglang/srt/entrypoints/http_server_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server_engine.py b/python/sglang/srt/entrypoints/http_server_engine.py index 32c8846b8f4..f4d81a417d7 100644 --- a/python/sglang/srt/entrypoints/http_server_engine.py +++ b/python/sglang/srt/entrypoints/http_server_engine.py @@ -57,7 +57,9 @@ class HttpServerEngineAdapter(EngineBase): def __init__(self, **kwargs): self.server_args = ServerArgs(**kwargs) - print(f"launch_server_from_verl_engine {self.server_args.port}") + print( + f"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}" + ) self.process = launch_server_process(self.server_args) def _make_request(self, endpoint: str, payload: Optional[dict] = None):
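Taken together, the series leaves HttpServerEngineAdapter as the server-backed counterpart to Engine. A usage sketch of its final shape is below; the model path, tp_size, sampling parameters, and the dummy weight are placeholders, and in the tests the named tensors come from an FSDP-sharded state dict on GPU.

import torch

from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter

engine = HttpServerEngineAdapter(
    model_path="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    tp_size=1,
)

out = engine.generate(
    prompt="1+1=2, 1+2=3, 1+3=",
    sampling_params={"temperature": 0.0, "max_new_tokens": 8},
)
print(out)

engine.release_memory_occupation()  # e.g. while the trainer holds the GPUs
engine.resume_memory_occupation()

engine.update_weights_from_tensor(
    named_tensors=[("lm_head.weight", torch.zeros(16, 16, device="cuda"))],  # dummy name/shape
    flush_cache=True,
)

engine.shutdown()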