Merged
Changes from 17 commits
Commits
42 commits
22f87aa
Support Multiple Torch Memory Saver for Multi-Stage-Awake
hebiao064 Jun 8, 2025
eead001
rm unnecessary code
hebiao064 Jun 8, 2025
b834ea6
add test
hebiao064 Jun 8, 2025
11e0461
uncomment disable_cuda_graph
hebiao064 Jun 8, 2025
a336515
fix server up issue
hebiao064 Jun 8, 2025
ceb314f
remove debugging code
hebiao064 Jun 8, 2025
7095f02
polish the test
hebiao064 Jun 8, 2025
d4eae34
Address reviewers feedback
hebiao064 Jun 8, 2025
062b349
modify test
hebiao064 Jun 9, 2025
10d9e26
fix test del model issue
hebiao064 Jun 9, 2025
0ea29ca
simplify code
hebiao064 Jun 9, 2025
bf5e4b3
Merge branch 'main' into bhe/support_multiple_tms
zhaochenyang20 Jun 9, 2025
effced3
removing unnecesary code comment
hebiao064 Jun 10, 2025
723609f
upd
hebiao064 Jun 10, 2025
c9d7be8
Merge branch 'bhe/support_multiple_tms' of https://github.com/sgl-pro…
hebiao064 Jun 10, 2025
3ca173f
Tag based Resume
hebiao064 Jun 11, 2025
ff53de8
fix
hebiao064 Jun 11, 2025
3253de6
update tms usage
hebiao064 Jun 15, 2025
fc8b0df
fix comments
hebiao064 Jun 15, 2025
29fab5b
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 15, 2025
b0be907
fix
hebiao064 Jun 16, 2025
33ef688
Merge branch 'bhe/tag_based_resume' of https://github.com/hebiao064/s…
hebiao064 Jun 16, 2025
1e04dc5
fix
hebiao064 Jun 16, 2025
0b077ed
remove some comments
hebiao064 Jun 16, 2025
e41a542
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 17, 2025
6cb1ff2
bump tms
hebiao064 Jun 17, 2025
0f35284
Merge branch 'main' into bhe/tag_based_resume
fzyzcjy Jun 17, 2025
bd5306b
fix test
hebiao064 Jun 17, 2025
3e62d19
Merge branch 'bhe/tag_based_resume' of https://github.com/hebiao064/s…
hebiao064 Jun 17, 2025
33a897b
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 17, 2025
443fae8
fix
hebiao064 Jun 18, 2025
abfb5db
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 18, 2025
83f9678
fix
hebiao064 Jun 18, 2025
7b19a92
Merge branch 'main' into bhe/tag_based_resume
zhyncs Jun 18, 2025
669637d
fix
hebiao064 Jun 18, 2025
c39a27d
Update test_release_memory_occupation.py
zhaochenyang20 Jun 18, 2025
35adb70
Merge branch 'main' into bhe/tag_based_resume
zhaochenyang20 Jun 19, 2025
4cb9c2e
add tp 2 tests
zhaochen20 Jun 19, 2025
0311c14
fix
hebiao064 Jun 19, 2025
e4a2613
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 19, 2025
14ae324
set to 0.85 for tp2
hebiao064 Jun 19, 2025
aa1e259
Merge branch 'bhe/tag_based_resume' of https://github.com/hebiao064/s…
hebiao064 Jun 19, 2025
25 changes: 12 additions & 13 deletions python/sglang/srt/entrypoints/engine.py
@@ -64,7 +64,9 @@
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.torch_memory_saver_adapter import (
configure_subprocess as tms_configure_subprocess,
)
from sglang.srt.utils import (
MultiprocessingSerializer,
assert_pkg_version,
@@ -465,17 +467,17 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100):
self.tokenizer_manager.get_weights_by_name(obj, None)
)

def release_memory_occupation(self):
"""Release GPU occupation temporarily."""
obj = ReleaseMemoryOccupationReqInput()
def release_memory_occupation(self, tags: Optional[List[str]] = None):
"""Pause GPU Memory occupation temporarily for Model Weights and KV Cache."""
obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.release_memory_occupation(obj, None)
)

def resume_memory_occupation(self):
"""Resume GPU occupation."""
obj = ResumeMemoryOccupationReqInput()
def resume_memory_occupation(self, tags: Optional[List[str]] = None):
"""Resume GPU Memory occupation for Model Weights and KV Cache."""
obj = ResumeMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.resume_memory_occupation(obj, None)
@@ -656,11 +658,6 @@ def _launch_subprocesses(

scheduler_procs = []
if server_args.dp_size == 1:
# Launch tensor parallel scheduler processes
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)

scheduler_pipe_readers = []

nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -696,7 +693,9 @@
writer,
),
)
with memory_saver_adapter.configure_subprocess():

# Preload the Torch Memory Saver C++ library so that all subprocesses can use torch memory saver
with tms_configure_subprocess(enable=server_args.enable_memory_saver):
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
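For context, a minimal usage sketch of the tag-scoped API introduced here, driven through the offline engine (the model path is illustrative, and enable_memory_saver=True is required for the calls to do anything):

import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    enable_memory_saver=True,
)

# Release the KV cache first, then the weights (LIFO order, matching allocation).
engine.release_memory_occupation(tags=["kv_cache"])
engine.release_memory_occupation(tags=["weights"])

# ... the freed GPU memory can now be used by a trainer or another process ...

# Resume in reverse: weights come back first, then the KV cache is rebuilt.
engine.resume_memory_occupation(tags=["weights"])
engine.resume_memory_occupation(tags=["kv_cache"])

Calling either method without tags (or with tags=None) keeps the old behavior and handles both regions in a single call.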
10 changes: 4 additions & 6 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -32,7 +32,9 @@
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.torch_memory_saver_adapter import (
configure_subprocess as tms_configure_subprocess,
)
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
from sglang.utils import get_exception_traceback

@@ -177,10 +179,6 @@ def launch_tensor_parallel_group(
if not server_args.enable_dp_attention:
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")

memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)

scheduler_pipe_readers = []

nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -233,7 +231,7 @@
writer,
),
)
with memory_saver_adapter.configure_subprocess():
with tms_configure_subprocess(enable=server_args.enable_memory_saver):
proc.start()
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
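The torch_memory_saver_adapter module itself is not part of this diff, so the exact shape of the new module-level configure_subprocess helper is not shown. A plausible sketch, assuming it simply wraps torch_memory_saver.configure_subprocess() the same way the old adapter class did, gated by an enable flag:

from contextlib import contextmanager

@contextmanager
def configure_subprocess(enable: bool):
    # Wrapped around proc.start() above: when the memory saver is enabled,
    # delegate to torch_memory_saver so the child process is launched with the
    # allocator hook preloaded; otherwise do nothing.
    if enable:
        import torch_memory_saver
        with torch_memory_saver.configure_subprocess():
            yield
    else:
        yield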
8 changes: 6 additions & 2 deletions python/sglang/srt/managers/io_struct.py
@@ -794,7 +794,9 @@ class GetWeightsByNameReqOutput:

@dataclass
class ReleaseMemoryOccupationReqInput:
pass
# Optional tags identifying the memory regions to release; primarily used for RLHF
# Currently we only support `weights` and `kv_cache`
tags: Optional[List[str]] = None


@dataclass
@@ -804,7 +806,9 @@ class ReleaseMemoryOccupationReqOutput:

@dataclass
class ResumeMemoryOccupationReqInput:
pass
# Optional tags identifying the memory regions to resume; primarily used for RLHF
# Currently we only support `weights` and `kv_cache`
tags: Optional[List[str]] = None


@dataclass
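The request payloads stay backward compatible: omitting tags (or passing None) means both regions, while an explicit list narrows the request. Illustrative construction:

# Release only the KV cache and keep the weights resident.
req_kv = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])

# No tags: behaves like the old request and covers weights and KV cache together.
req_all = ReleaseMemoryOccupationReqInput()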
60 changes: 42 additions & 18 deletions python/sglang/srt/managers/scheduler.py
@@ -422,8 +422,10 @@ def __init__(
self.parent_process = psutil.Process().parent()

# Init memory saver
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
from sglang.srt.torch_memory_saver_adapter import torch_memory_saver_adapter

self.memory_saver_adapter = torch_memory_saver_adapter(
server_args.enable_memory_saver
)

# Init profiler
@@ -1167,7 +1169,7 @@ def log_prefill_stats(
else:
f += f"#queue-req: {len(self.waiting_queue)}"

logger.info(f)
# logger.info(f)

if self.enable_metrics:
cache_hit_rate = adder.log_hit_tokens / (
@@ -1234,7 +1236,7 @@ def log_decode_stats(
f"#queue-req: {len(self.waiting_queue)}"
)

logger.info(msg)
# logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
@@ -2121,23 +2123,45 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return GetWeightsByNameReqOutput(parameter)

def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
self.memory_saver_adapter.check_validity(
caller_name="release_memory_occupation"
)
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
self.memory_saver_adapter.pause()
self.flush_cache()
tags = recv_req.tags
import subprocess

if tags is None:
# for backward compatibility, default to release both weights and kv cache
tags = ["weights", "kv_cache"]

# LIFO order: pause kv_cache first, then weights
if "kv_cache" in tags:
self.memory_saver_adapter.pause("kv_cache")
self.flush_cache()

if "weights" in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
self.memory_saver_adapter.pause("weights")

return ReleaseMemoryOccupationReqOutput()

def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
self.memory_saver_adapter.resume()
_import_static_state(
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
)
del self.stashed_model_static_state
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = ["weights", "kv_cache"]

if "weights" in tags:
self.memory_saver_adapter.resume("weights")
_import_static_state(
self.tp_worker.worker.model_runner.model,
self.stashed_model_static_state,
)
del self.stashed_model_static_state

if "kv_cache" in tags:
self.memory_saver_adapter.resume("kv_cache")
torch.cuda.synchronize()
time.sleep(3)

torch.distributed.barrier(self.tp_cpu_group)
return ResumeMemoryOccupationReqOutput()

def slow_down(self, recv_req: SlowDownReqInput):
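The ordering in the two handlers above is deliberate: release pauses the KV cache before the weights, and resume brings the weights back before the KV cache. Continuing the engine example shown after the engine.py diff, the multi-stage awake pattern this enables looks roughly like the following sketch (updated_named_tensors is a placeholder for the trainer's list of (name, tensor) pairs):

# Stage 1: resume only the weights while the KV cache is still released.
engine.resume_memory_occupation(tags=["weights"])

# Stage 2: push the trainer's new policy weights into the inference engine.
engine.update_weights_from_tensor(named_tensors=updated_named_tensors)

# Stage 3: re-materialize the KV cache once the updated weights are in place.
engine.resume_memory_occupation(tags=["kv_cache"])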
20 changes: 10 additions & 10 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -61,14 +61,14 @@ def __init__(
device: str,
enable_memory_saver: bool,
):
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
from sglang.srt.torch_memory_saver_adapter import torch_memory_saver_adapter

memory_saver_adapter = torch_memory_saver_adapter(enable_memory_saver)

self.size = size
self.max_context_len = max_context_len
self.device = device
with memory_saver_adapter.region():
with memory_saver_adapter.region("kv_cache"):
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
@@ -124,9 +124,9 @@ def __init__(
self.layer_num = layer_num
self.start_layer = start_layer or 0
self.end_layer = end_layer or layer_num - 1
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
from sglang.srt.torch_memory_saver_adapter import torch_memory_saver_adapter

self.memory_saver_adapter = torch_memory_saver_adapter(enable_memory_saver)

@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -275,7 +275,7 @@ def __init__(
)

def _create_buffers(self):
with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region("kv_cache"):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
@@ -536,7 +536,7 @@ def __init__(
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim

with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region("kv_cache"):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
@@ -671,7 +671,7 @@ def __init__(
end_layer,
)

with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region("kv_cache"):
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.zeros(
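Every KV-cache pool above now allocates its buffers inside a region tagged "kv_cache", which is what lets the scheduler pause and resume that memory independently of the weights. Conceptually it works as in the simplified sketch below (not code from this PR; memory_saver_adapter refers to the adapter created in the surrounding code, and the tensor shape is illustrative):

import torch

with memory_saver_adapter.region("kv_cache"):
    k_buffer = torch.zeros((1024, 8, 128), dtype=torch.bfloat16, device="cuda")

# Physical memory behind everything tagged "kv_cache" is dropped, but the
# virtual addresses (and the Python tensor objects) remain valid.
memory_saver_adapter.pause("kv_cache")

# Physical memory is attached again; k_buffer is usable without re-creating it.
memory_saver_adapter.resume("kv_cache")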
8 changes: 5 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
@@ -217,8 +217,10 @@ def __init__(

def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
from sglang.srt.torch_memory_saver_adapter import torch_memory_saver_adapter

self.memory_saver_adapter = torch_memory_saver_adapter(
self.server_args.enable_memory_saver
)

if not self.is_draft_worker:
@@ -539,7 +541,7 @@ def load_model(self):
monkey_patch_vllm_parallel_state()
monkey_patch_isinstance_for_vllm_base_layer()

with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region("weights"):
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
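Several files above obtain the adapter through a new torch_memory_saver_adapter(enable) factory whose definition is not included in this diff. A rough sketch of what it could return, assuming the bumped torch-memory-saver release (see the "bump tms" commit) exposes tag-aware region()/pause()/resume() on its singleton, with a no-op fallback when the flag is off:

from contextlib import nullcontext
from typing import Optional

class _NoopMemorySaverAdapter:
    # Fallback when --enable-memory-saver is off: every call does nothing.
    def region(self, tag: str):
        return nullcontext()

    def pause(self, tag: Optional[str] = None):
        pass

    def resume(self, tag: Optional[str] = None):
        pass

def torch_memory_saver_adapter(enable: bool):
    if not enable:
        return _NoopMemorySaverAdapter()
    # Assumed interface: the torch_memory_saver singleton accepts a tag in
    # region()/pause()/resume() in the release this PR bumps to.
    from torch_memory_saver import torch_memory_saver
    return torch_memory_saver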