diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 7648abe1e797..1aa861232004 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -357,7 +357,6 @@ jobs: runs-on: 1-gpu-runner strategy: fail-fast: false - max-parallel: 5 matrix: part: [0, 1] steps: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 9c57b428f44f..fd2040345dcd 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -282,6 +282,7 @@ class Envs: SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False) # VLM + SGLANG_VLM_CACHE_SIZE_MB = EnvInt(100) SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28) SGLANG_RESIZE_RESAMPLE = EnvStr("") diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index d9af0c88337e..6564f11883d7 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -95,12 +95,7 @@ def __init__( self.schedule_low_priority_values_first = schedule_low_priority_values_first # It is used to find the matching prefix for in-batch prefix caching. 
- self.waiting_queue_radix_tree = RadixCache( - req_to_token_pool=None, - token_to_kv_pool_allocator=None, - page_size=1, - disable=False, - ) + self.waiting_queue_radix_tree = RadixCache.create_simulated() def calc_priority(self, waiting_queue: List[Req]) -> bool: if self.policy == CacheAgnosticPolicy.FCFS: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e1bd79331479..f5d94978254a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -148,12 +148,9 @@ ) from sglang.srt.managers.session_controller import Session from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length -from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache +from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.common import release_kv_cache -from sglang.srt.mem_cache.hiradix_cache import HiRadixCache -from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache -from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin from sglang.srt.parser.reasoning_parser import ReasoningParser @@ -419,8 +416,8 @@ def __init__( # Init metrics stats self.init_metrics(tp_rank, pp_rank, dp_rank) - # Init memory pool and cache - self.init_memory_pool_and_cache() + # Init cache using the existing memory pool + self.init_cache_with_memory_pool() # Init running status self.waiting_queue: List[Req] = [] @@ -693,117 +690,81 @@ def init_tokenizer(self): reasoning_parser.detector.think_end_token, add_special_tokens=False )[0] - def init_memory_pool_and_cache(self): + def init_cache_with_memory_pool(self): server_args = self.server_args self.req_to_token_pool, self.token_to_kv_pool_allocator = ( self.tp_worker.get_memory_pool() ) + 
params = CacheInitParams( + disable=server_args.disable_radix_cache, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + page_size=self.page_size, + is_eagle=self.spec_algorithm.is_eagle(), + tp_cache_group=( + self.attn_tp_cpu_group + if self.server_args.enable_dp_attention + else self.tp_cpu_group + ), + eviction_policy=server_args.radix_eviction_policy, + enable_metrics=self.enable_metrics, + enable_kv_cache_events=self.enable_kv_cache_events, + ) + if ( server_args.chunked_prefill_size is not None and server_args.disable_radix_cache ): - if self.is_hybrid: - ChunkCacheClass = SWAChunkCache + if not self.is_hybrid: + from sglang.srt.mem_cache.chunk_cache import ChunkCache + + self.tree_cache = ChunkCache(params) else: - ChunkCacheClass = ChunkCache - self.tree_cache = ChunkCacheClass( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - page_size=self.page_size, - ) + + from sglang.srt.mem_cache.chunk_cache import SWAChunkCache + + self.tree_cache = SWAChunkCache(params) else: + if envs.SGLANG_EXPERIMENTAL_CPP_RADIX_TREE.get(): # lazy import to avoid JIT overhead from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp logger.info("Using experimental C++ radix tree implementation.") - self.tree_cache = RadixCacheCpp( - disable=False, - use_hicache=self.enable_hierarchical_cache, - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - tp_cache_group=self.tp_cpu_group, - page_size=self.page_size, - hicache_ratio=server_args.hicache_ratio, - hicache_size=server_args.hicache_size, - hicache_write_policy=server_args.hicache_write_policy, - enable_metrics=self.enable_metrics, - enable_kv_cache_events=self.enable_kv_cache_events, - ) + self.tree_cache = RadixCacheCpp(params=params, server_args=server_args) elif self.enable_hierarchical_cache: - self.tree_cache = HiRadixCache( - 
req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - tp_cache_group=( - self.attn_tp_cpu_group - if self.server_args.enable_dp_attention - else self.tp_cpu_group - ), - page_size=self.page_size, - eviction_policy=server_args.radix_eviction_policy, - hicache_ratio=server_args.hicache_ratio, - hicache_size=server_args.hicache_size, - hicache_write_policy=server_args.hicache_write_policy, - hicache_io_backend=server_args.hicache_io_backend, - hicache_mem_layout=server_args.hicache_mem_layout, - enable_metrics=self.enable_metrics, - hicache_storage_backend=server_args.hicache_storage_backend, - hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy, - model_name=server_args.served_model_name, - storage_backend_extra_config=server_args.hicache_storage_backend_extra_config, - is_eagle=self.spec_algorithm.is_eagle(), - ) + from sglang.srt.mem_cache.hiradix_cache import HiRadixCache + + self.tree_cache = HiRadixCache(params=params, server_args=server_args) self.tp_worker.register_hicache_layer_transfer_counter( self.tree_cache.cache_controller.layer_done_counter ) elif self.is_hybrid: + from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache + self.tree_cache = SWARadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - sliding_window_size=self.sliding_window_size, - page_size=self.page_size, - disable=server_args.disable_radix_cache, - is_eagle=self.spec_algorithm.is_eagle(), - enable_metrics=self.enable_metrics, + params=params, sliding_window_size=self.sliding_window_size ) elif self.is_hybrid_gdn: - self.tree_cache = MambaRadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - page_size=self.page_size, - disable=server_args.disable_radix_cache, - enable_metrics=self.enable_metrics, - ) + from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache + + 
self.tree_cache = MambaRadixCache(params) elif server_args.enable_lmcache: from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import ( LMCRadixCache, ) self.tree_cache = LMCRadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - page_size=self.page_size, - disable=server_args.disable_radix_cache, - enable_metrics=self.enable_metrics, + params=params, model_config=self.model_config, tp_size=self.tp_size, rank=self.tp_rank, tp_group=self.tp_group, - eviction_policy=server_args.radix_eviction_policy, ) else: - self.tree_cache = RadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - page_size=self.page_size, - disable=server_args.disable_radix_cache, - enable_metrics=self.enable_metrics, - enable_kv_cache_events=self.enable_kv_cache_events, - eviction_policy=server_args.radix_eviction_policy, - is_eagle=self.spec_algorithm.is_eagle(), - ) + self.tree_cache = RadixCache(params) if ( server_args.disaggregation_mode == "decode" @@ -812,11 +773,7 @@ def init_memory_pool_and_cache(self): self.decode_offload_manager = DecodeKVCacheOffloadManager( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - tp_group=( - self.attn_tp_cpu_group - if self.server_args.enable_dp_attention - else self.tp_cpu_group - ), + tp_group=params.tp_cache_group, tree_cache=self.tree_cache, server_args=self.server_args, ) @@ -835,7 +792,7 @@ def init_memory_pool_and_cache(self): ) ) - embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100")) + embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get() init_mm_embedding_cache(embedding_cache_size * 1024 * 1024) def init_disaggregation(self): diff --git a/python/sglang/srt/mem_cache/cache_init_params.py b/python/sglang/srt/mem_cache/cache_init_params.py new file mode 100644 index 000000000000..06ca57521c97 --- /dev/null +++ 
b/python/sglang/srt/mem_cache/cache_init_params.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Optional + +import torch + +if TYPE_CHECKING: + from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator + from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + + +@dataclasses.dataclass +class CacheInitParams: + disable: bool + req_to_token_pool: ReqToTokenPool + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator + page_size: int + + is_eagle: bool = False + tp_cache_group: Optional[torch.distributed.ProcessGroup] = None + eviction_policy: str = "lru" + disable_finished_insert: bool = False + + enable_metrics: bool = False + enable_kv_cache_events: bool = False diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index aae3294b5c2d..dcc899e46532 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -6,27 +6,19 @@ import torch -from sglang.srt.mem_cache.allocator import ( - BaseTokenToKVPoolAllocator, - SWATokenToKVPoolAllocator, -) +from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req + from sglang.srt.mem_cache.cache_init_params import CacheInitParams class ChunkCache(BasePrefixCache): - def __init__( - self, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, - page_size: int, - ): - self.req_to_token_pool = req_to_token_pool - self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - self.page_size = page_size + def __init__(self, params: CacheInitParams): + self.req_to_token_pool = params.req_to_token_pool + self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator + self.page_size = 
params.page_size if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device else: @@ -89,14 +81,9 @@ def pretty_print(self): class SWAChunkCache(ChunkCache): """ChunkCache with support for hybrid KV cache operations.""" - def __init__( - self, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: SWATokenToKVPoolAllocator, - page_size: int, - ): - super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size) - assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + def __init__(self, params: CacheInitParams): + assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + super().__init__(params) def evict_swa( self, diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 87ed378ec986..8a2815b59235 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -1,20 +1,17 @@ +from __future__ import annotations + import heapq import json import logging import threading import time -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional import torch from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation -from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import MatchResult -from sglang.srt.mem_cache.memory_pool import ( - MHATokenToKVPool, - MLATokenToKVPool, - ReqToTokenPool, -) +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool from sglang.srt.mem_cache.memory_pool_host import ( MHATokenToKVPoolHost, MLATokenToKVPoolHost, @@ -22,62 +19,49 @@ from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.metrics.collector import StorageMetricsCollector +if TYPE_CHECKING: + from sglang.srt.mem_cache.cache_init_params import CacheInitParams + from sglang.srt.server_args import ServerArgs + 
logger = logging.getLogger(__name__) class HiRadixCache(RadixCache): - def __init__( - self, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, - tp_cache_group: torch.distributed.ProcessGroup, - page_size: int, - hicache_ratio: float, - hicache_size: int, - hicache_write_policy: str, - hicache_io_backend: str, - hicache_mem_layout: str, - enable_metrics: bool, - eviction_policy: str = "lru", - hicache_storage_backend: Optional[str] = None, - hicache_storage_prefetch_policy: Optional[str] = "best_effort", - model_name: Optional[str] = None, - storage_backend_extra_config: Optional[str] = None, - is_eagle: bool = False, - ): - - if hicache_io_backend == "direct": - if hicache_mem_layout == "page_first": - hicache_mem_layout = "page_first_direct" + def __init__(self, params: CacheInitParams, server_args: ServerArgs): + if server_args.hicache_io_backend == "direct": + # FIXME: move this logic into server_args parsing + if server_args.hicache_mem_layout == "page_first": + server_args.hicache_mem_layout = "page_first_direct" logger.warning( "Page first layout is not supported with direct IO backend, switching to page first direct layout" ) - self.kv_cache = token_to_kv_pool_allocator.get_kvcache() + self.page_size = params.page_size + self.kv_cache = params.token_to_kv_pool_allocator.get_kvcache() if isinstance(self.kv_cache, MHATokenToKVPool): self.token_to_kv_pool_host = MHATokenToKVPoolHost( self.kv_cache, - hicache_ratio, - hicache_size, - page_size, - hicache_mem_layout, + server_args.hicache_ratio, + server_args.hicache_size, + self.page_size, + server_args.hicache_mem_layout, ) elif isinstance(self.kv_cache, MLATokenToKVPool): self.token_to_kv_pool_host = MLATokenToKVPoolHost( self.kv_cache, - hicache_ratio, - hicache_size, - page_size, - hicache_mem_layout, + server_args.hicache_ratio, + server_args.hicache_size, + self.page_size, + server_args.hicache_mem_layout, ) else: raise ValueError(f"HiRadixCache only supports 
MHA and MLA yet") - self.tp_group = tp_cache_group + self.tp_group = params.tp_cache_group self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) - self.enable_storage = hicache_storage_backend is not None - self.enable_storage_metrics = self.enable_storage and enable_metrics + self.enable_storage = server_args.hicache_storage_backend is not None + self.enable_storage_metrics = self.enable_storage and params.enable_metrics ( extra_config, @@ -85,35 +69,37 @@ def __init__( prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys, - ) = self._parse_storage_backend_extra_config(storage_backend_extra_config) + ) = self._parse_storage_backend_extra_config( + server_args.hicache_storage_backend_extra_config + ) self.prefetch_threshold = prefetch_threshold self.prefetch_timeout_base = prefetch_timeout_base self.prefetch_timeout_per_page = ( - page_size / 1024 * prefetch_timeout_per_ki_token + self.page_size / 1024 * prefetch_timeout_per_ki_token ) self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys # TODO: support more timeout check functions self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func - self.prefetch_stop_policy = hicache_storage_prefetch_policy + self.prefetch_stop_policy = server_args.hicache_storage_prefetch_policy self.load_cache_event = threading.Event() self.cache_controller = HiCacheController( - token_to_kv_pool_allocator, + params.token_to_kv_pool_allocator, self.token_to_kv_pool_host, - page_size, + self.page_size, self.tp_group, load_cache_event=self.load_cache_event, - write_policy=hicache_write_policy, - io_backend=hicache_io_backend, - storage_backend=hicache_storage_backend, + write_policy=server_args.hicache_write_policy, + io_backend=server_args.hicache_io_backend, + storage_backend=server_args.hicache_storage_backend, prefetch_threshold=self.prefetch_threshold, - model_name=model_name, + model_name=server_args.served_model_name, 
storage_backend_extra_config=extra_config, ) if self.enable_storage_metrics: # TODO: support pp labels = { - "storage_backend": hicache_storage_backend, + "storage_backend": server_args.hicache_storage_backend, "tp_rank": self.cache_controller.tp_rank, "dp_rank": self.cache_controller.dp_rank, } @@ -128,19 +114,11 @@ def __init__( self.ongoing_backup = {} # todo: dynamically adjust the threshold self.write_through_threshold = ( - 1 if hicache_write_policy == "write_through" else 2 + 1 if server_args.hicache_write_policy == "write_through" else 2 ) self.load_back_threshold = 10 - super().__init__( - req_to_token_pool, - token_to_kv_pool_allocator, - page_size, - disable=False, - eviction_policy=eviction_policy, - is_eagle=is_eagle, - enable_metrics=enable_metrics, - ) + super().__init__(params=params) def _parse_storage_backend_extra_config( self, storage_backend_extra_config: Optional[str] diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index 7a9b4c0d0926..31a37f86dec1 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -28,7 +28,6 @@ from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult -from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool from sglang.srt.mem_cache.radix_cache import ( RadixKey, _key_match_page_size1, @@ -37,6 +36,7 @@ if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req + from sglang.srt.mem_cache.cache_init_params import CacheInitParams import logging @@ -320,28 +320,23 @@ def sanity_check(self, tree_cache: "MambaRadixCache"): class MambaRadixCache(BasePrefixCache): - def __init__( - self, - req_to_token_pool: HybridReqToTokenPool, - token_to_kv_pool_allocator: TokenToKVPoolAllocator, - page_size: int, - disable: bool = False, - enable_metrics: bool = False, - ): - assert 
isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator) - self.req_to_token_pool = req_to_token_pool - self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - - assert page_size == 1, "Only support page_size=1 in mamba radix cache now." - self.page_size = page_size - self.disable = disable + def __init__(self, params: CacheInitParams): + assert isinstance(params.token_to_kv_pool_allocator, TokenToKVPoolAllocator) + self.req_to_token_pool = params.req_to_token_pool + self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator + + assert ( + params.page_size == 1 + ), "Only support page_size=1 in mamba radix cache now." + self.page_size = params.page_size + self.disable = params.disable if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device else: self.device = torch.device("cpu") - if enable_metrics: + if params.enable_metrics: self.init_metrics_collector() self.key_match_fn = _key_match_page_size1 diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index e528cf116bac..fae975de8ee3 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -1,5 +1,6 @@ from __future__ import annotations +from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.utils import convert_to_bigram_key """ @@ -26,7 +27,7 @@ import time from collections import defaultdict from functools import lru_cache, partial -from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union import torch @@ -35,7 +36,6 @@ BlockRemoved, BlockStored, ) -from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.evict_policy import ( EvictionStrategy, @@ -46,7 +46,6 @@ MRUStrategy, PriorityStrategy, ) -from 
sglang.srt.mem_cache.memory_pool import ReqToTokenPool if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -187,28 +186,19 @@ def get_child_key(key: RadixKey, page_size: int = 1): class RadixCache(BasePrefixCache): - def __init__( - self, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, - page_size: int, - disable: bool = False, - enable_metrics: bool = False, - enable_kv_cache_events: bool = False, - eviction_policy: str = "lru", - is_eagle: bool = False, - disable_finished_insert: bool = False, - ): - self.req_to_token_pool = req_to_token_pool - self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - self.page_size = page_size - self.disable = disable - self.enable_kv_cache_events = enable_kv_cache_events + def __init__(self, params: CacheInitParams): + self.disable = params.disable + self.req_to_token_pool = params.req_to_token_pool + self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator + self.page_size = params.page_size + self.enable_kv_cache_events = params.enable_kv_cache_events + self.is_eagle = params.is_eagle + self.disable_finished_insert = params.disable_finished_insert + self.eviction_policy = params.eviction_policy.lower() + self.kv_event_queue = [] - self.is_eagle = is_eagle - self.disable_finished_insert = disable_finished_insert - if enable_metrics: + if params.enable_metrics: self.init_metrics_collector() if self.token_to_kv_pool_allocator: @@ -220,27 +210,45 @@ def __init__( self.key_match_fn = _key_match_page_size1 self.get_child_key_fn = get_child_key else: - self.key_match_fn = partial(_key_match_paged, page_size=page_size) - self.get_child_key_fn = partial(get_child_key, page_size=page_size) + self.key_match_fn = partial(_key_match_paged, page_size=self.page_size) + self.get_child_key_fn = partial(get_child_key, page_size=self.page_size) - if eviction_policy.lower() == "lru": + if self.eviction_policy == "lru": self.eviction_strategy: EvictionStrategy = 
LRUStrategy() - elif eviction_policy.lower() == "lfu": + elif self.eviction_policy == "lfu": self.eviction_strategy: EvictionStrategy = LFUStrategy() - elif eviction_policy.lower() == "fifo": + elif self.eviction_policy == "fifo": self.eviction_strategy: EvictionStrategy = FIFOStrategy() - elif eviction_policy.lower() == "mru": + elif self.eviction_policy == "mru": self.eviction_strategy: EvictionStrategy = MRUStrategy() - elif eviction_policy.lower() == "filo": + elif self.eviction_policy == "filo": self.eviction_strategy: EvictionStrategy = FILOStrategy() - elif eviction_policy.lower() == "priority": + elif self.eviction_policy == "priority": self.eviction_strategy: EvictionStrategy = PriorityStrategy() else: raise ValueError( - f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu', 'fifo', 'mru', 'filo', 'priority'." + f"Unknown eviction policy: {self.eviction_policy}. Supported policies: 'lru', 'lfu', 'fifo', 'mru', 'filo', 'priority'." ) self.reset() + @classmethod + def create_simulated( + cls, + disable: bool = False, + mock_allocator: Optional[Any] = None, + page_size: int = 1, + enable_kv_cache_events: bool = False, + ) -> RadixCache: + """Init a radix cache without memory pools for simulation purpose.""" + params = CacheInitParams( + disable=disable, + req_to_token_pool=None, + token_to_kv_pool_allocator=mock_allocator, + page_size=page_size, + enable_kv_cache_events=enable_kv_cache_events, + ) + return cls(params) + ##### Public API ##### def reset(self): @@ -743,7 +751,7 @@ def take_events(self): if __name__ == "__main__": - tree = RadixCache(None, None, page_size=1, disable=False) + tree = RadixCache.create_simulated() # Example token id sequences (as lists of ints) tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None)) diff --git a/python/sglang/srt/mem_cache/radix_cache_cpp.py b/python/sglang/srt/mem_cache/radix_cache_cpp.py index 4938c2e69fde..e8f187b467ef 100644 --- 
a/python/sglang/srt/mem_cache/radix_cache_cpp.py +++ b/python/sglang/srt/mem_cache/radix_cache_cpp.py @@ -6,62 +6,36 @@ import torch -from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import ( IOHandle, RadixTreeCpp, TreeNodeCpp, ) -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req + from sglang.srt.mem_cache.cache_init_params import CacheInitParams + from sglang.srt.server_args import ServerArgs logger = logging.getLogger(__name__) class RadixCacheCpp(BasePrefixCache): - def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor: - """ - Merge a list of tensors into a single tensor. - Args: - l (List[torch.Tensor]): List of tensors to merge. - Returns: - torch.Tensor: Merged tensor. - """ - if len(l) == 0: - return torch.empty(0, dtype=torch.int64, device=self.device) - elif len(l) == 1: - return l[0] - else: - return torch.cat(l) - def __init__( self, - disable: bool, - use_hicache: bool, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, - tp_cache_group: torch.distributed.ProcessGroup, - page_size: int, - hicache_ratio: float, - hicache_size: int, - hicache_write_policy: str, - enable_metrics: bool = False, - enable_kv_cache_events: bool = False, - hicache_oracle: bool = False, + params: CacheInitParams, + server_args: ServerArgs, enable_write_cancel: bool = False, ): - self.disable = disable + self.disable = params.disable self.enable_write_cancel = enable_write_cancel assert ( - enable_kv_cache_events is False + params.enable_kv_cache_events is False ), "HiRadixCache does not support kv cache events yet" - self.kv_cache = token_to_kv_pool_allocator.get_kvcache() # record the nodes with ongoing write through 
self.ongoing_write_through: Set[IOHandle] = set() @@ -69,22 +43,23 @@ def __init__( self.ongoing_load_back: Set[IOHandle] = set() # todo: dynamically adjust the threshold self.write_through_threshold = ( - 1 if hicache_write_policy == "write_through" else 2 + 1 if server_args.hicache_write_policy == "write_through" else 2 ) - self.device = token_to_kv_pool_allocator.device - self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - self.req_to_token_pool = req_to_token_pool - self.page_size = page_size + self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator + self.device = self.token_to_kv_pool_allocator.device + self.req_to_token_pool = params.req_to_token_pool + self.page_size = params.page_size + self.kv_cache = self.token_to_kv_pool_allocator.get_kvcache() - self.tp_group = tp_cache_group + self.tp_group = params.tp_cache_group - if enable_metrics: + if params.enable_metrics: self.init_metrics_collector() - if not use_hicache: + if not server_args.enable_hierarchical_cache: self.tree = RadixTreeCpp( disabled=self.disable, - page_size=page_size, + page_size=self.page_size, host_size=None, # no host cache, this should be removed in the future write_through_threshold=self.write_through_threshold, ) @@ -93,6 +68,21 @@ raise NotImplementedError("Host cache is not supported yet") + def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor: + """ + Merge a list of tensors into a single tensor. + Args: + l (List[torch.Tensor]): List of tensors to merge. + Returns: + torch.Tensor: Merged tensor. 
+ """ + if len(l) == 0: + return torch.empty(0, dtype=torch.int64, device=self.device) + elif len(l) == 1: + return l[0] + else: + return torch.cat(l) + def reset(self): if self.cache_controller is not None: # need to clear the acks before resetting the cache controller diff --git a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py index 820653b4ba8d..c697ebc79051 100644 --- a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +++ b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py @@ -6,9 +6,7 @@ import torch -from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import MatchResult -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode try: @@ -25,6 +23,7 @@ if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import Req + from sglang.srt.mem_cache.cache_init_params import CacheInitParams logger = logging.getLogger(__name__) @@ -69,27 +68,13 @@ class LMCRadixCache(RadixCache): def __init__( self, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, - page_size: int, - disable: bool = False, - enable_metrics: bool = False, - enable_kv_cache_events: bool = False, + params: CacheInitParams, model_config: Optional["ModelConfig"] = None, tp_size: int = 1, rank: int = 0, tp_group: Optional[torch.distributed.ProcessGroup] = None, - eviction_policy: str = "lru", ): - super().__init__( - req_to_token_pool=req_to_token_pool, - token_to_kv_pool_allocator=token_to_kv_pool_allocator, - page_size=page_size, - disable=disable, - enable_metrics=enable_metrics, - enable_kv_cache_events=enable_kv_cache_events, - eviction_policy=eviction_policy, - ) + super().__init__(params) kvcache = self.token_to_kv_pool_allocator.get_kvcache() 
self.lmcache_connector = LMCacheLayerwiseConnector( @@ -271,12 +256,17 @@ def pretty_print(self): # type: ignore[override] if __name__ == "__main__": - cache = LMCRadixCache( + from sglang.srt.mem_cache.cache_init_params import CacheInitParams + + params = CacheInitParams( req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1, disable=False, enable_kv_cache_events=False, + ) + cache = LMCRadixCache( + params=params, model_config=None, tp_size=1, rank=0, diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 2e97aca51c7b..33f55c1bb01d 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -30,7 +30,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult -from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.radix_cache import ( RadixKey, _key_match_page_size1, @@ -329,22 +329,13 @@ def sanity_check(self, tree_cache: "SWARadixCache"): class SWARadixCache(BasePrefixCache): - def __init__( - self, - req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: SWATokenToKVPoolAllocator, - sliding_window_size: int, - page_size: int, - disable: bool = False, - is_eagle: bool = False, - enable_metrics: bool = False, - ): - assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) - self.req_to_token_pool = req_to_token_pool - self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - self.page_size = page_size - self.disable = disable - self.is_eagle = is_eagle + def __init__(self, params: CacheInitParams, sliding_window_size: int): + assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + self.req_to_token_pool = params.req_to_token_pool + self.token_to_kv_pool_allocator = 
params.token_to_kv_pool_allocator + self.page_size = params.page_size + self.disable = params.disable + self.is_eagle = params.is_eagle if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device @@ -355,15 +346,15 @@ def __init__( self.key_match_fn = _key_match_page_size1 self.get_child_key_fn = get_child_key else: - self.key_match_fn = partial(_key_match_paged, page_size=page_size) - self.get_child_key_fn = partial(get_child_key, page_size=page_size) + self.key_match_fn = partial(_key_match_paged, page_size=self.page_size) + self.get_child_key_fn = partial(get_child_key, page_size=self.page_size) - if is_eagle: + if self.is_eagle: self.key_convert_fn = convert_to_bigram_key else: self.key_convert_fn = lambda key: key - if enable_metrics: + if params.enable_metrics: self.init_metrics_collector() self.sliding_window_size = sliding_window_size diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0ac752b93fb9..dae29f89c8d1 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -56,6 +56,7 @@ from multiprocessing.reduction import ForkingPickler from pathlib import Path from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -94,6 +95,9 @@ from sglang.srt.environ import envs from sglang.srt.metrics.func_timer import enable_func_timer +if TYPE_CHECKING: + from sglang.srt.server_args import ServerArgs + logger = logging.getLogger(__name__) show_time_cost = False @@ -2754,7 +2758,7 @@ def with_value(self, new_value: T): self._value = None -def require_mlp_tp_gather(server_args): +def require_mlp_tp_gather(server_args: ServerArgs): """ Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups. 
""" @@ -2777,7 +2781,7 @@ def require_mlp_tp_gather(server_args): return False -def require_attn_tp_gather(server_args): +def require_attn_tp_gather(server_args: ServerArgs): """ Check if the input of attention is scattered. """ @@ -2791,11 +2795,11 @@ def require_attn_tp_gather(server_args): return False -def require_gathered_buffer(server_args): +def require_gathered_buffer(server_args: ServerArgs): return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args) -def require_mlp_sync(server_args): +def require_mlp_sync(server_args: ServerArgs): return server_args.enable_dp_attention or require_gathered_buffer(server_args) diff --git a/test/manual/test_schedule_policy.py b/test/manual/test_schedule_policy.py index 2be092e3146a..747a247e83b8 100644 --- a/test/manual/test_schedule_policy.py +++ b/test/manual/test_schedule_policy.py @@ -14,7 +14,7 @@ class TestSchedulePolicy(CustomTestCase): def setUp(self): - self.tree_cache = RadixCache(None, None, False) + self.tree_cache = RadixCache.create_simulated() def test_init_with_cache_aware_policy(self): policy = SchedulePolicy( @@ -47,10 +47,10 @@ def test_init_with_unknown_policy(self): ) def test_init_with_disabled_cache(self): - disabled_tree_cache = RadixCache(None, None, disable=True, page_size=1) + tree_cache = RadixCache.create_simulated(disable=True) policy = SchedulePolicy( policy="lpm", - tree_cache=disabled_tree_cache, + tree_cache=tree_cache, enable_hierarchical_cache=True, enable_priority_scheduling=False, schedule_low_priority_values_first=False, @@ -58,7 +58,7 @@ def test_init_with_disabled_cache(self): self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) def test_calc_priority_fcfs(self): - tree_cache = RadixCache(None, None, False) + tree_cache = RadixCache.create_simulated() waiting_queue = [ Req(1, "a b", [1, 2], SamplingParams()), Req(3, "a b c", [1, 2, 3], SamplingParams()), @@ -79,16 +79,15 @@ def test_calc_priority_fcfs(self): self.assertEqual(waiting_queue[2].rid, 2) def 
test_calc_priority_priority_enabled_fcfs_scheduling(self): - tree_cache = RadixCache(None, None, False) + tree_cache = RadixCache.create_simulated() + r1 = Req(1, "a b", [1, 2], SamplingParams()) + r2 = Req(3, "a b c", [1, 2, 3], SamplingParams()) + r3 = Req(2, "a", [1], SamplingParams()) + r1.priority, r1.time_stats.wait_queue_entry_time = 1, 1 + r2.priority, r2.time_stats.wait_queue_entry_time = 0, 1 + r3.priority, r3.time_stats.wait_queue_entry_time = 0, 0 - waiting_queue = [ - Req(1, "a b", [1, 2], SamplingParams()), - Req(3, "a b c", [1, 2, 3], SamplingParams()), - Req(2, "a", [1], SamplingParams()), - ] - waiting_queue[0].priority, waiting_queue[0].queue_time_start = 1, 1 - waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1 - waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0 + waiting_queue = [r1, r2, r3] policy = SchedulePolicy( policy="fcfs", @@ -98,6 +97,7 @@ def test_calc_priority_priority_enabled_fcfs_scheduling(self): schedule_low_priority_values_first=False, ) policy.calc_priority(waiting_queue) + # Check if priority enabled fcfs ordering is applied. 
self.assertEqual(waiting_queue[0].rid, 1) self.assertEqual(waiting_queue[1].rid, 2) @@ -106,16 +106,15 @@ def test_calc_priority_priority_enabled_fcfs_scheduling(self): def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values_first( self, ): - tree_cache = RadixCache(None, None, False) + tree_cache = RadixCache.create_simulated() + r1 = Req(1, "a b", [1, 2], SamplingParams()) + r2 = Req(3, "a b c", [1, 2, 3], SamplingParams()) + r3 = Req(2, "a", [1], SamplingParams()) + r1.priority, r1.time_stats.wait_queue_entry_time = -1, 1 + r2.priority, r2.time_stats.wait_queue_entry_time = 0, 1 + r3.priority, r3.time_stats.wait_queue_entry_time = 0, 0 - waiting_queue = [ - Req(1, "a b", [1, 2], SamplingParams()), - Req(3, "a b c", [1, 2, 3], SamplingParams()), - Req(2, "a", [1], SamplingParams()), - ] - waiting_queue[0].priority, waiting_queue[0].queue_time_start = -1, 0 - waiting_queue[1].priority, waiting_queue[1].queue_time_start = 0, 1 - waiting_queue[2].priority, waiting_queue[2].queue_time_start = 0, 0 + waiting_queue = [r1, r2, r3] policy = SchedulePolicy( policy="fcfs", @@ -131,7 +130,7 @@ def test_calc_priority_priority_enabled_fcfs_scheduling_with_low_priority_values self.assertEqual(waiting_queue[2].rid, 3) def test_calc_priority_longest_output_first_scheduling(self): - tree_cache = RadixCache(None, None, False) + tree_cache = RadixCache.create_simulated() waiting_queue = [ Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1000)), @@ -153,7 +152,7 @@ def test_calc_priority_longest_output_first_scheduling(self): self.assertEqual(waiting_queue[2].rid, 3) def test_calc_priority_priority_enabled_longest_output_first_scheduling(self): - tree_cache = RadixCache(None, None, False) + tree_cache = RadixCache.create_simulated() waiting_queue = [ Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=1), @@ -177,7 +176,7 @@ def test_calc_priority_priority_enabled_longest_output_first_scheduling(self): def 
test_calc_priority_priority_enabled_longest_output_first_scheduling_with_low_priority_values_first( self, ): - tree_cache = RadixCache(None, None, False) + tree_cache = RadixCache.create_simulated() waiting_queue = [ Req(1, "a b", [1, 2], SamplingParams(max_new_tokens=1), priority=0), diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py index ae93c415121a..d72cad94e734 100644 --- a/test/srt/test_mamba_unittest.py +++ b/test/srt/test_mamba_unittest.py @@ -6,6 +6,7 @@ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape from sglang.srt.managers.schedule_batch import Req from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator +from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey @@ -186,13 +187,14 @@ def test_mamba_radix_cache_1(self): kvcache=pool, need_sort=False, ) - # setup radix cache - tree = MambaRadixCache( + params = CacheInitParams( req_to_token_pool=req_to_token_pool, token_to_kv_pool_allocator=allocator, page_size=1, disable=False, ) + # setup radix cache + tree = MambaRadixCache(params=params) def make_dummy_req(): sampling_params = SamplingParams( diff --git a/test/srt/test_radix_cache_unit.py b/test/srt/test_radix_cache_unit.py index 8cb75fb0bf84..eadb338abf12 100644 --- a/test/srt/test_radix_cache_unit.py +++ b/test/srt/test_radix_cache_unit.py @@ -240,11 +240,9 @@ def test_init_variations(self): with self.subTest( page_size=page_size, disable=disable, enable_events=enable_events ): - cache = RadixCache( - req_to_token_pool=None, - token_to_kv_pool_allocator=None, - page_size=page_size, + cache = RadixCache.create_simulated( disable=disable, + page_size=page_size, enable_kv_cache_events=enable_events, ) @@ -257,9 +255,7 @@ def test_init_variations(self): def 
test_reset(self): """Test reset method.""" - cache = RadixCache( - req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 - ) + cache = RadixCache.create_simulated() # Insert some data cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) @@ -275,12 +271,7 @@ def test_insert_and_match_basic(self): """Test basic insert and match operations.""" for disable_cache in [False, True]: with self.subTest(disable_cache=disable_cache): - cache = RadixCache( - req_to_token_pool=None, - token_to_kv_pool_allocator=None, - page_size=1, - disable=disable_cache, - ) + cache = RadixCache.create_simulated(disable=disable_cache) key = RadixKey([1, 2, 3]) value = torch.tensor([10, 20, 30], dtype=torch.int64) @@ -309,9 +300,7 @@ def test_insert_and_match_basic(self): def test_insert_with_none_value(self): """Test insert with None value (should use token_ids as list).""" - cache = RadixCache( - req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 - ) + cache = RadixCache.create_simulated() key = RadixKey([1, 2, 3]) prefix_len = cache.insert(key, None) @@ -322,9 +311,7 @@ def test_insert_with_none_value(self): def test_total_size(self): """Test total_size calculation.""" - cache = RadixCache( - req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 - ) + cache = RadixCache.create_simulated() self.assertEqual(cache.total_size(), 0) @@ -344,11 +331,8 @@ def test_kv_cache_events(self): for page_size, enable_events in test_cases: with self.subTest(page_size=page_size, enable_events=enable_events): - cache = RadixCache( - req_to_token_pool=None, - token_to_kv_pool_allocator=None, - page_size=page_size, - enable_kv_cache_events=enable_events, + cache = RadixCache.create_simulated( + page_size=page_size, enable_kv_cache_events=enable_events ) # Insert data @@ -374,11 +358,8 @@ def test_kv_cache_events_with_eviction(self): mock_allocator = unittest.mock.Mock() mock_allocator.device = torch.device("cpu") - cache = RadixCache( - 
req_to_token_pool=None, - token_to_kv_pool_allocator=mock_allocator, - page_size=1, - enable_kv_cache_events=True, + cache = RadixCache.create_simulated( + mock_allocator=mock_allocator, enable_kv_cache_events=True ) # Insert and then evict data @@ -400,9 +381,7 @@ def test_kv_cache_events_with_eviction(self): def test_extra_key_isolation(self): """Test that keys with different extra_key values are isolated.""" - cache = RadixCache( - req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 - ) + cache = RadixCache.create_simulated() # Insert same token sequence with different extra keys cache.insert( @@ -442,9 +421,7 @@ def test_extra_key_isolation(self): def test_lock_ref_operations(self): """Test lock reference counting operations.""" - cache = RadixCache( - req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 - ) + cache = RadixCache.create_simulated() # Insert sequence cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) @@ -471,11 +448,7 @@ def test_evict_functionality(self): mock_allocator = unittest.mock.Mock() mock_allocator.device = torch.device("cpu") - cache = RadixCache( - req_to_token_pool=None, - token_to_kv_pool_allocator=mock_allocator, - page_size=1, - ) + cache = RadixCache.create_simulated(mock_allocator=mock_allocator) # Insert sequences cache.insert(RadixKey([1, 2]), torch.tensor([10, 20], dtype=torch.int64)) @@ -500,11 +473,7 @@ def test_page_alignment_boundary(self): for page_size, sequence_length in test_cases: with self.subTest(page_size=page_size, sequence_length=sequence_length): - cache = RadixCache( - req_to_token_pool=None, - token_to_kv_pool_allocator=None, - page_size=page_size, - ) + cache = RadixCache.create_simulated(page_size=page_size) tokens = list(range(sequence_length)) cache.insert(RadixKey(tokens), torch.tensor(tokens, dtype=torch.int64)) @@ -518,9 +487,7 @@ def test_page_alignment_boundary(self): def test_pretty_print_basic(self): """Test pretty_print produces 
output.""" - cache = RadixCache( - req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 - ) + cache = RadixCache.create_simulated() cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) @@ -532,9 +499,7 @@ def test_pretty_print_basic(self): def test_all_values_flatten(self): """Test all_values_flatten method.""" - cache = RadixCache( - req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1 - ) + cache = RadixCache.create_simulated() cache.insert(RadixKey([1, 2]), torch.tensor([10, 20], dtype=torch.int64)) cache.insert(RadixKey([3, 4]), torch.tensor([30, 40], dtype=torch.int64)) @@ -549,11 +514,7 @@ def test_advanced_prefix_match_with_node_splits(self): """Advanced prefix matching: splits inside nodes and across pages.""" for page_size in [1, 2]: with self.subTest(page_size=page_size): - cache = RadixCache( - req_to_token_pool=None, - token_to_kv_pool_allocator=None, - page_size=page_size, - ) + cache = RadixCache.create_simulated(page_size=page_size) # Insert a long sequence that will be split later. 
seq1 = [1, 2, 3, 4, 5, 6, 7, 8] diff --git a/test/srt/test_swa_unittest.py b/test/srt/test_swa_unittest.py index b11435b8f9f2..2d01f90bd05d 100644 --- a/test/srt/test_swa_unittest.py +++ b/test/srt/test_swa_unittest.py @@ -3,6 +3,7 @@ import torch from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator +from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -110,11 +111,13 @@ def test_swa_radix_cache_1(self): ) # setup radix cache tree = SWARadixCache( - req_to_token_pool=req_to_token_pool, - token_to_kv_pool_allocator=allocator, + params=CacheInitParams( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=allocator, + disable=False, + page_size=1, + ), sliding_window_size=sliding_window_size, - page_size=1, - disable=False, ) # test @@ -241,12 +244,14 @@ def test_swa_radix_cache_eagle(self): ) # setup radix cache tree = SWARadixCache( - req_to_token_pool=req_to_token_pool, - token_to_kv_pool_allocator=allocator, + params=CacheInitParams( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=allocator, + page_size=1, + disable=False, + is_eagle=True, + ), sliding_window_size=sliding_window_size, - page_size=1, - disable=False, - is_eagle=True, ) # test