Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ jobs:
runs-on: 1-gpu-runner
strategy:
fail-fast: false
max-parallel: 5
matrix:
part: [0, 1]
steps:
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ class Envs:
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)

# VLM
SGLANG_VLM_CACHE_SIZE_MB = EnvInt(100)
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
SGLANG_RESIZE_RESAMPLE = EnvStr("")

Expand Down
7 changes: 1 addition & 6 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,7 @@ def __init__(
self.schedule_low_priority_values_first = schedule_low_priority_values_first

# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None,
token_to_kv_pool_allocator=None,
page_size=1,
disable=False,
)
self.waiting_queue_radix_tree = RadixCache.create_simulated()

def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
Expand Down
129 changes: 43 additions & 86 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,9 @@
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.common import release_kv_cache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
from sglang.srt.parser.reasoning_parser import ReasoningParser
Expand Down Expand Up @@ -419,8 +416,8 @@ def __init__(
# Init metrics stats
self.init_metrics(tp_rank, pp_rank, dp_rank)

# Init memory pool and cache
self.init_memory_pool_and_cache()
# Init cache using the existing memory pool
self.init_cache_with_memory_pool()

# Init running status
self.waiting_queue: List[Req] = []
Expand Down Expand Up @@ -693,117 +690,81 @@ def init_tokenizer(self):
reasoning_parser.detector.think_end_token, add_special_tokens=False
)[0]

def init_memory_pool_and_cache(self):
def init_cache_with_memory_pool(self):
server_args = self.server_args

self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)

params = CacheInitParams(
disable=server_args.disable_radix_cache,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
is_eagle=self.spec_algorithm.is_eagle(),
tp_cache_group=(
self.attn_tp_cpu_group
if self.server_args.enable_dp_attention
else self.tp_cpu_group
),
eviction_policy=server_args.radix_eviction_policy,
enable_metrics=self.enable_metrics,
enable_kv_cache_events=self.enable_kv_cache_events,
)
Comment on lines +700 to +714
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a great refactoring to introduce CacheInitParams and centralize the cache initialization parameters. I notice that RadixCache, ChunkCache, and SWAChunkCache have been updated to accept the params object directly. However, other cache implementations like RadixCacheCpp, HiRadixCache, SWARadixCache, MambaRadixCache, and LMCRadixCache are still initialized with individual arguments.

To improve consistency and maintainability, it would be beneficial to extend this refactoring to these other cache classes as well, so they also accept a CacheInitParams object in their constructors.


if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
if self.is_hybrid:
ChunkCacheClass = SWAChunkCache
if not self.is_hybrid:
from sglang.srt.mem_cache.chunk_cache import ChunkCache

self.tree_cache = ChunkCache(params)
else:
ChunkCacheClass = ChunkCache
self.tree_cache = ChunkCacheClass(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
)

from sglang.srt.mem_cache.chunk_cache import SWAChunkCache

self.tree_cache = SWAChunkCache(params)
else:

if envs.SGLANG_EXPERIMENTAL_CPP_RADIX_TREE.get():
# lazy import to avoid JIT overhead
from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

logger.info("Using experimental C++ radix tree implementation.")
self.tree_cache = RadixCacheCpp(
disable=False,
use_hicache=self.enable_hierarchical_cache,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_cpu_group,
page_size=self.page_size,
hicache_ratio=server_args.hicache_ratio,
hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy,
enable_metrics=self.enable_metrics,
enable_kv_cache_events=self.enable_kv_cache_events,
)
self.tree_cache = RadixCacheCpp(params=params, server_args=server_args)
elif self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=(
self.attn_tp_cpu_group
if self.server_args.enable_dp_attention
else self.tp_cpu_group
),
page_size=self.page_size,
eviction_policy=server_args.radix_eviction_policy,
hicache_ratio=server_args.hicache_ratio,
hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy,
hicache_io_backend=server_args.hicache_io_backend,
hicache_mem_layout=server_args.hicache_mem_layout,
enable_metrics=self.enable_metrics,
hicache_storage_backend=server_args.hicache_storage_backend,
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
model_name=server_args.served_model_name,
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
is_eagle=self.spec_algorithm.is_eagle(),
)
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache

self.tree_cache = HiRadixCache(params=params, server_args=server_args)
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
)
elif self.is_hybrid:
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache

self.tree_cache = SWARadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
sliding_window_size=self.sliding_window_size,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
enable_metrics=self.enable_metrics,
params=params, sliding_window_size=self.sliding_window_size
)
elif self.is_hybrid_gdn:
self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
enable_metrics=self.enable_metrics,
)
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache

self.tree_cache = MambaRadixCache(params)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
LMCRadixCache,
)

self.tree_cache = LMCRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
enable_metrics=self.enable_metrics,
params=params,
model_config=self.model_config,
tp_size=self.tp_size,
rank=self.tp_rank,
tp_group=self.tp_group,
eviction_policy=server_args.radix_eviction_policy,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
enable_metrics=self.enable_metrics,
enable_kv_cache_events=self.enable_kv_cache_events,
eviction_policy=server_args.radix_eviction_policy,
is_eagle=self.spec_algorithm.is_eagle(),
)
self.tree_cache = RadixCache(params)

if (
server_args.disaggregation_mode == "decode"
Expand All @@ -812,11 +773,7 @@ def init_memory_pool_and_cache(self):
self.decode_offload_manager = DecodeKVCacheOffloadManager(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_group=(
self.attn_tp_cpu_group
if self.server_args.enable_dp_attention
else self.tp_cpu_group
),
tp_group=params.tp_cache_group,
tree_cache=self.tree_cache,
server_args=self.server_args,
)
Expand All @@ -835,7 +792,7 @@ def init_memory_pool_and_cache(self):
)
)

embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get()
init_mm_embedding_cache(embedding_cache_size * 1024 * 1024)

def init_disaggregation(self):
Expand Down
26 changes: 26 additions & 0 deletions python/sglang/srt/mem_cache/cache_init_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Optional

import torch

if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool


@dataclasses.dataclass
class CacheInitParams:
    """Bundle of constructor arguments shared by the prefix/radix cache classes.

    Centralizes the parameters that were previously passed individually to
    each cache implementation (RadixCache, ChunkCache, SWAChunkCache, ...),
    so cache constructors can accept a single ``params`` object.
    """

    # When True, prefix caching is disabled (driven by
    # ``server_args.disable_radix_cache`` at the call site).
    disable: bool
    # Pool mapping request slots to token indices.
    req_to_token_pool: ReqToTokenPool
    # Allocator for token -> KV-cache slots.
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator
    # Page granularity (in tokens) for KV-cache allocation.
    page_size: int

    # True when the EAGLE speculative-decoding algorithm is active.
    is_eagle: bool = False
    # Process group used for cache-related collectives; callers pass the
    # attention-TP CPU group when DP attention is enabled, else the TP CPU group.
    tp_cache_group: Optional[torch.distributed.ProcessGroup] = None
    # Cache eviction policy name (e.g. "lru").
    eviction_policy: str = "lru"
    # NOTE(review): not set at the visible call site — presumably skips
    # inserting finished requests into the tree; confirm against cache impls.
    disable_finished_insert: bool = False

    # Enable metrics collection for the cache.
    enable_metrics: bool = False
    # Enable emission of KV-cache events.
    enable_kv_cache_events: bool = False
31 changes: 9 additions & 22 deletions python/sglang/srt/mem_cache/chunk_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,19 @@

import torch

from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.cache_init_params import CacheInitParams


class ChunkCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
def __init__(self, params: CacheInitParams):
self.req_to_token_pool = params.req_to_token_pool
self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator
self.page_size = params.page_size
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
else:
Expand Down Expand Up @@ -89,14 +81,9 @@ def pretty_print(self):
class SWAChunkCache(ChunkCache):
"""ChunkCache with support for hybrid KV cache operations."""

def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
page_size: int,
):
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
def __init__(self, params: CacheInitParams):
assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
super().__init__(params)

def evict_swa(
self,
Expand Down
Loading
Loading