Merged

Commits (42)
22f87aa
Support Multiple Torch Memory Saver for Multi-Stage-Awake
hebiao064 Jun 8, 2025
eead001
rm unnecessary code
hebiao064 Jun 8, 2025
b834ea6
add test
hebiao064 Jun 8, 2025
11e0461
uncomment disable_cuda_graph
hebiao064 Jun 8, 2025
a336515
fix server up issue
hebiao064 Jun 8, 2025
ceb314f
remove debugging code
hebiao064 Jun 8, 2025
7095f02
polish the test
hebiao064 Jun 8, 2025
d4eae34
Address reviewers feedback
hebiao064 Jun 8, 2025
062b349
modify test
hebiao064 Jun 9, 2025
10d9e26
fix test del model issue
hebiao064 Jun 9, 2025
0ea29ca
simplify code
hebiao064 Jun 9, 2025
bf5e4b3
Merge branch 'main' into bhe/support_multiple_tms
zhaochenyang20 Jun 9, 2025
effced3
removing unnecesary code comment
hebiao064 Jun 10, 2025
723609f
upd
hebiao064 Jun 10, 2025
c9d7be8
Merge branch 'bhe/support_multiple_tms' of https://github.com/sgl-pro…
hebiao064 Jun 10, 2025
3ca173f
Tag based Resume
hebiao064 Jun 11, 2025
ff53de8
fix
hebiao064 Jun 11, 2025
3253de6
update tms usage
hebiao064 Jun 15, 2025
fc8b0df
fix comments
hebiao064 Jun 15, 2025
29fab5b
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 15, 2025
b0be907
fix
hebiao064 Jun 16, 2025
33ef688
Merge branch 'bhe/tag_based_resume' of https://github.com/hebiao064/s…
hebiao064 Jun 16, 2025
1e04dc5
fix
hebiao064 Jun 16, 2025
0b077ed
remove some comments
hebiao064 Jun 16, 2025
e41a542
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 17, 2025
6cb1ff2
bump tms
hebiao064 Jun 17, 2025
0f35284
Merge branch 'main' into bhe/tag_based_resume
fzyzcjy Jun 17, 2025
bd5306b
fix test
hebiao064 Jun 17, 2025
3e62d19
Merge branch 'bhe/tag_based_resume' of https://github.com/hebiao064/s…
hebiao064 Jun 17, 2025
33a897b
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 17, 2025
443fae8
fix
hebiao064 Jun 18, 2025
abfb5db
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 18, 2025
83f9678
fix
hebiao064 Jun 18, 2025
7b19a92
Merge branch 'main' into bhe/tag_based_resume
zhyncs Jun 18, 2025
669637d
fix
hebiao064 Jun 18, 2025
c39a27d
Update test_release_memory_occupation.py
zhaochenyang20 Jun 18, 2025
35adb70
Merge branch 'main' into bhe/tag_based_resume
zhaochenyang20 Jun 19, 2025
4cb9c2e
add tp 2 tests
zhaochen20 Jun 19, 2025
0311c14
fix
hebiao064 Jun 19, 2025
e4a2613
Merge branch 'main' into bhe/tag_based_resume
hebiao064 Jun 19, 2025
14ae324
set to 0.85 for tp2
hebiao064 Jun 19, 2025
aa1e259
Merge branch 'bhe/tag_based_resume' of https://github.com/hebiao064/s…
hebiao064 Jun 19, 2025
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
decord = ["decord"]
test = [
"accelerate",
3 changes: 3 additions & 0 deletions python/sglang/srt/constants.py
@@ -0,0 +1,3 @@
# GPU Memory Types
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_WEIGHTS = "weights"
3 changes: 2 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
@@ -31,6 +31,7 @@
import torch
from torch.distributed import ProcessGroup

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
@@ -90,7 +91,7 @@ def __init__(
self.max_context_len = max_context_len
self.device = device
self.pre_alloc_size = pre_alloc_size
with memory_saver_adapter.region():
with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
self.req_to_token = torch.zeros(
(size + pre_alloc_size, max_context_len),
dtype=torch.int32,
13 changes: 5 additions & 8 deletions python/sglang/srt/entrypoints/engine.py
@@ -479,17 +479,15 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100):
self.tokenizer_manager.get_weights_by_name(obj, None)
)

def release_memory_occupation(self):
"""Release GPU occupation temporarily."""
obj = ReleaseMemoryOccupationReqInput()
def release_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.release_memory_occupation(obj, None)
)

def resume_memory_occupation(self):
"""Resume GPU occupation."""
obj = ResumeMemoryOccupationReqInput()
def resume_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ResumeMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.resume_memory_occupation(obj, None)
@@ -670,11 +668,9 @@ def _launch_subprocesses(

scheduler_procs = []
if server_args.dp_size == 1:
# Launch tensor parallel scheduler processes
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)

scheduler_pipe_readers = []

nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +706,7 @@
writer,
),
)

with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)
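Note: `release_memory_occupation` and `resume_memory_occupation` now take an optional `tags` list, which is what enables the multi-stage wake-up this PR targets: release everything while the RL trainer owns the GPU, then bring the weight region back first (so refreshed weights can be loaded) and the KV cache last. A minimal sketch of that flow, assuming an `Engine` launched with `enable_memory_saver=True`; `get_updated_weights()` is a hypothetical stand-in for whatever the trainer provides:

```python
import sglang as sgl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    enable_memory_saver=True,
)

# Release both regions before handing the GPU to the training framework.
engine.release_memory_occupation(
    tags=[GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
)

# ... run a training step elsewhere ...

# Stage 1: map the weight region back and load the refreshed parameters.
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
engine.update_weights_from_tensor(named_tensors=get_updated_weights())

# Stage 2: re-materialize the KV cache once the new weights are in place.
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
```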
8 changes: 6 additions & 2 deletions python/sglang/srt/managers/io_struct.py
@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:

@dataclass
class ReleaseMemoryOccupationReqInput:
pass
# Optional tags identifying the memory regions to act on; primarily used for RL.
# Currently only `weights` and `kv_cache` are supported.
tags: Optional[List[str]] = None


@dataclass
@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:

@dataclass
class ResumeMemoryOccupationReqInput:
pass
# Optional tags identifying the memory regions to act on; primarily used for RL.
# Currently only `weights` and `kv_cache` are supported.
tags: Optional[List[str]] = None


@dataclass
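For reference, these request dataclasses carry the tags from the tokenizer manager to the scheduler; leaving `tags` unset preserves the old all-at-once behavior. A small illustration, constructing the objects directly outside the normal request path:

```python
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.managers.io_struct import (
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
)

# Release only the KV cache; weights stay resident.
req = ReleaseMemoryOccupationReqInput(tags=[GPU_MEMORY_TYPE_KV_CACHE])
assert req.tags == ["kv_cache"]

# Omitting tags means "all regions" — the scheduler expands None to both types.
req_all = ResumeMemoryOccupationReqInput()
assert req_all.tags is None
```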
48 changes: 32 additions & 16 deletions python/sglang/srt/managers/scheduler.py
@@ -36,6 +36,7 @@

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
create_grammar_backend,
@@ -450,8 +451,6 @@ def __init__(
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
self.parent_process = psutil.Process().parent()

# Init memory saver
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
@@ -2227,23 +2226,40 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return GetWeightsByNameReqOutput(parameter)

def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
self.memory_saver_adapter.check_validity(
caller_name="release_memory_occupation"
)
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
self.memory_saver_adapter.pause()
self.flush_cache()
tags = recv_req.tags

if tags is None:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
self.flush_cache()

if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

return ReleaseMemoryOccupationReqOutput()

def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
self.memory_saver_adapter.resume()
_import_static_state(
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
)
del self.stashed_model_static_state
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
_import_static_state(
self.tp_worker.worker.model_runner.model,
self.stashed_model_static_state,
)
del self.stashed_model_static_state

if GPU_MEMORY_TYPE_KV_CACHE in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

return ResumeMemoryOccupationReqOutput()

def slow_down(self, recv_req: SlowDownReqInput):
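The scheduler handles the two tags independently and in a fixed order: on release, the KV cache is paused and flushed before the weights are stashed and paused; on resume, the weights (and their exported static state) come back before the KV cache. A rough way to observe the effect on a CUDA machine, reusing the `engine` from the earlier sketch (the printed values will vary by model and GPU):

```python
import torch
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

def free_gib() -> float:
    free_bytes, _total = torch.cuda.mem_get_info()
    return free_bytes / 1024**3

print(f"baseline:          {free_gib():.1f} GiB free")

engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
print(f"KV cache released: {free_gib():.1f} GiB free")

engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
print(f"weights released:  {free_gib():.1f} GiB free")

# Resume in the opposite order: weights first, then the KV cache.
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
print(f"fully resumed:     {free_gib():.1f} GiB free")
```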
10 changes: 6 additions & 4 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -35,6 +35,7 @@
import triton
import triton.language as tl

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2

@@ -54,14 +55,15 @@ def __init__(
device: str,
enable_memory_saver: bool,
):

memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)

self.size = size
self.max_context_len = max_context_len
self.device = device
with memory_saver_adapter.region():
with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
@@ -292,7 +294,7 @@ def __init__(
)

def _create_buffers(self):
with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
@@ -610,7 +612,7 @@ def __init__(
else:
self.custom_mem_pool = None

with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
@@ -753,7 +755,7 @@ def __init__(
end_layer,
)

with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.zeros(
4 changes: 3 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -30,6 +30,7 @@
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
get_tp_group,
get_world_group,
@@ -222,6 +223,7 @@ def __init__(

def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args

self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
)
@@ -547,7 +549,7 @@ def load_model(self):
monkey_patch_vllm_parallel_state()
monkey_patch_isinstance_for_vllm_base_layer()

with self.memory_saver_adapter.region():
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
34 changes: 19 additions & 15 deletions python/sglang/srt/torch_memory_saver_adapter.py
@@ -1,11 +1,13 @@
import logging
import threading
import time
from abc import ABC
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext

try:
import torch_memory_saver

_primary_memory_saver = torch_memory_saver.TorchMemorySaver()
_memory_saver = torch_memory_saver.torch_memory_saver
import_error = None
except ImportError as e:
import_error = e
@@ -38,13 +40,13 @@ def check_validity(self, caller_name):
def configure_subprocess(self):
raise NotImplementedError

def region(self):
def region(self, tag: str):
raise NotImplementedError

def pause(self):
def pause(self, tag: str):
raise NotImplementedError

def resume(self):
def resume(self, tag: str):
raise NotImplementedError

@property
Expand All @@ -53,21 +55,23 @@ def enabled(self):


class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
"""Adapter for TorchMemorySaver with tag-based control"""

def configure_subprocess(self):
return torch_memory_saver.configure_subprocess()

def region(self):
return _primary_memory_saver.region()
def region(self, tag: str):
return _memory_saver.region(tag=tag)

def pause(self):
return _primary_memory_saver.pause()
def pause(self, tag: str):
return _memory_saver.pause(tag=tag)

def resume(self):
return _primary_memory_saver.resume()
def resume(self, tag: str):
return _memory_saver.resume(tag=tag)

@property
def enabled(self):
return _primary_memory_saver.enabled
return _memory_saver is not None and _memory_saver.enabled


class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ def configure_subprocess(self):
yield

@contextmanager
def region(self):
def region(self, tag: str):
yield

def pause(self):
def pause(self, tag: str):
pass

def resume(self):
def resume(self, tag: str):
pass

@property
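The adapter now delegates to the library's global `torch_memory_saver` singleton rather than a privately constructed `TorchMemorySaver`, forwarding the tag on every call. A rough sketch of the underlying library usage the adapter assumes (torch_memory_saver >= 0.0.8; tensor sizes are illustrative, and the library's allocator hook must be active, e.g. via `configure_subprocess()` / its LD_PRELOAD mechanism):

```python
import torch
import torch_memory_saver

saver = torch_memory_saver.torch_memory_saver  # global singleton, as used by the adapter

# Tensors allocated inside a tagged region can later be paused/resumed by tag.
with saver.region(tag="weights"):
    weights = torch.zeros(1024, 1024, device="cuda")

with saver.region(tag="kv_cache"):
    kv_cache = torch.zeros(4096, 1024, device="cuda")

saver.pause(tag="kv_cache")   # release only the KV-cache physical memory
saver.resume(tag="kv_cache")  # map it back; the tensor's virtual address stays valid
```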
1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -37,6 +37,7 @@
# General test models
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
