Merged
27 commits
614b781  support cp in hicache (vladnosiv, Feb 26, 2026)
d4849ed  Merge branch 'main' into hicache-and-cp (vladnosiv, Mar 17, 2026)
ce27bde  Merge branch 'main' into hicache-and-cp (vladnosiv, Mar 19, 2026)
4ba927a  Merge branch 'main' into hicache-and-cp (vladnosiv, Mar 23, 2026)
1640288  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 2, 2026)
81ded49  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 6, 2026)
794c777  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 7, 2026)
4bd0278  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 8, 2026)
2629f95  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 14, 2026)
66eaf6f  simplify diff (vladnosiv, Apr 14, 2026)
e361af5  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 16, 2026)
9bc09c3  fix automerge (vladnosiv, Apr 16, 2026)
2f40399  add qwen3-30b hicache+cp+tp+storage test (vladnosiv, Apr 16, 2026)
4e04e92  fix file storage suffixes (vladnosiv, Apr 16, 2026)
f6f5975  add dsv32 hicache+cp+tp+storage test (vladnosiv, Apr 16, 2026)
8b2edbe  add dsv32 hicache+NO-CP+tp+storage test (vladnosiv, Apr 16, 2026)
5a1e6fb  fix hicache suffixes for cp (vladnosiv, Apr 16, 2026)
0a2cdef  remove cp info from hicache keys (vladnosiv, Apr 16, 2026)
2792423  replace tests (vladnosiv, Apr 16, 2026)
0a5747d  Merge branch 'main' into hicache-and-cp (hzh0425, Apr 17, 2026)
fa6499c  fix missed moe-dp-size 2 in cp test (vladnosiv, Apr 17, 2026)
b273628  Merge branch 'main' into hicache-and-cp (xiezhq-hermann, Apr 17, 2026)
d2d05e6  Merge remote-tracking branch 'origin/main' into hicache-and-cp (vladnosiv, Apr 20, 2026)
5a2a4ff  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 21, 2026)
643b083  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 22, 2026)
c6f6cb3  fix bug in conflict resolve (vladnosiv, Apr 22, 2026)
49a3cb5  Merge branch 'main' into hicache-and-cp (vladnosiv, Apr 27, 2026)
116 changes: 67 additions & 49 deletions python/sglang/srt/managers/cache_controller.py
@@ -253,6 +253,8 @@ def __init__(
page_size: int,
tp_group: torch.distributed.ProcessGroup,
load_cache_event: threading.Event,
attn_cp_group: Optional[torch.distributed.ProcessGroup] = None,
attn_tp_group: Optional[torch.distributed.ProcessGroup] = None,
write_policy: str = "write_through_selective",
io_backend: str = "",
storage_backend: Optional[str] = None,
@@ -261,11 +263,12 @@ def __init__(
storage_backend_extra_config: Optional[dict] = None,
pp_rank: int = 0,
pp_size: int = 1,
attn_cp_rank: int = 0,
attn_cp_size: int = 1,
enable_storage_metrics: bool = False,
):
self.tp_group = tp_group
self.attn_cp_group = attn_cp_group
self.attn_tp_group = attn_tp_group
self.prefetch_sync_groups: List[torch.distributed.ProcessGroup] = []
self.mem_pool_device_allocator = token_to_kv_pool_allocator
mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool
@@ -282,8 +285,6 @@ def __init__(
self.storage_backend_type = None
self.pp_rank = pp_rank
self.pp_size = pp_size
self.attn_cp_rank = attn_cp_rank
self.attn_cp_size = attn_cp_size
self.enable_storage_metrics = enable_storage_metrics

# Default storage page IO functions (may be overridden by attach).
@@ -337,6 +338,51 @@ def __init__(
# Preserve the historical error shape on init for unknown backends.
raise ValueError(f"Failed to create storage backend: {e}") from e

def get_attn_cp_rank_and_size(self) -> tuple[int, int]:
"""Derive CP rank/size from the attn_cp process group."""
if self.attn_cp_group is not None:
return (
torch.distributed.get_rank(group=self.attn_cp_group),
torch.distributed.get_world_size(group=self.attn_cp_group),
)
return 0, 1

def _create_prefetch_sync_groups(self) -> None:
from sglang.srt.distributed.parallel_state import create_custom_parallel_group

self.prefetch_sync_groups = []
seen_rank_sets = set()

if self.attn_cp_group is not None or self.attn_tp_group is not None:
base_groups = [self.attn_cp_group, self.attn_tp_group]
else:
base_groups = [self.tp_group]

for group in base_groups:
if group is None or torch.distributed.get_world_size(group=group) == 1:
continue
group_ranks = tuple(torch.distributed.get_process_group_ranks(group))
if group_ranks in seen_rank_sets:
continue
seen_rank_sets.add(group_ranks)
self.prefetch_sync_groups.append(
create_custom_parallel_group(
group_ranks=list(group_ranks), backend="gloo"
)
)

def _destroy_prefetch_sync_groups(self) -> None:
for group in self.prefetch_sync_groups:
try:
torch.distributed.destroy_process_group(group)
except Exception:
pass
self.prefetch_sync_groups = []

def _all_reduce_prefetch_groups(self, tensor: torch.Tensor, op) -> None:
for group in self.prefetch_sync_groups:
torch.distributed.all_reduce(tensor, op=op, group=group)

def _start_storage_threads(self):
"""Start storage prefetch/backup threads and their queues.

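A note on the dedup in _create_prefetch_sync_groups above: when the attention CP and TP axes collapse onto the same rank set (for example cp_size == 1, or a layout where both groups span identical ranks), creating two gloo groups over the same ranks would run every storage sync twice. Below is a minimal standalone sketch of that rank-set dedup, with plain rank tuples standing in for torch.distributed process groups; the function name is illustrative, not part of the PR.

    # Minimal sketch of the rank-set dedup; tuples stand in for process groups.
    def dedup_sync_groups(candidate_groups):
        """Keep one sync group per distinct rank set, skipping trivial groups."""
        sync_groups = []
        seen_rank_sets = set()
        for ranks in candidate_groups:
            if ranks is None or len(ranks) == 1:
                continue  # size-1 groups need no synchronization
            key = tuple(ranks)
            if key in seen_rank_sets:
                continue  # CP and TP collapsed onto the same ranks: sync once
            seen_rank_sets.add(key)
            sync_groups.append(key)
        return sync_groups

    # attn_cp and attn_tp both span ranks (0, 1): only one gloo group is made.
    assert dedup_sync_groups([(0, 1), (0, 1)]) == [(0, 1)]
    # Distinct axes each get their own group.
    assert dedup_sync_groups([(0, 2), (0, 1)]) == [(0, 2), (0, 1)]
    # Missing or trivial groups are skipped entirely.
    assert dedup_sync_groups([None, (0,)]) == []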
@@ -467,17 +513,9 @@ def attach_storage_backend(
# tracking the number of tokens locked in prefetching, updated by the main scheduler thread
self.prefetch_tokens_occupied = 0

# create a new communication group for synchronizing storage operations across TP workers
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
if self.tp_world_size > 1:
from sglang.srt.distributed.parallel_state import (
create_custom_parallel_group,
)

group_ranks = torch.distributed.get_process_group_ranks(self.tp_group)
self.prefetch_tp_group = create_custom_parallel_group(
group_ranks=group_ranks, backend="gloo"
)
# Use dedicated gloo groups so storage prefetch sync is isolated
# from other collectives and consistent across CPxTP participants.
self._create_prefetch_sync_groups()

# Select the get and set functions
self.page_get_func = self._generic_page_get
@@ -502,15 +540,7 @@ def attach_storage_backend(
self._stop_storage_threads()
except Exception:
pass
try:
if hasattr(self, "prefetch_tp_group"):
try:
torch.distributed.destroy_process_group(self.prefetch_tp_group)
except Exception:
pass
self.prefetch_tp_group = None
except Exception:
pass
self._destroy_prefetch_sync_groups()
try:
if (
hasattr(self, "storage_backend")
@@ -547,19 +577,8 @@ def detach_storage_backend(self):
# to avoid flipping `enable_storage` flags while threads are still alive.
raise RuntimeError("Stop storage threads failed; detach aborted.") from e

# Best-effort destroy process group created for storage ops.
try:
if (
hasattr(self, "prefetch_tp_group")
and self.prefetch_tp_group is not None
):
try:
torch.distributed.destroy_process_group(self.prefetch_tp_group)
except Exception:
pass
self.prefetch_tp_group = None
except Exception:
pass
# Best-effort destroy process groups created for storage ops.
self._destroy_prefetch_sync_groups()

# Best-effort close (some backends rely on GC/destructor).
try:
@@ -613,13 +632,15 @@ def _generate_storage_config(
and tp_lcm_size > self.tp_size
)

attn_cp_rank, attn_cp_size = self.get_attn_cp_rank_and_size()

return HiCacheStorageConfig(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
pp_rank=self.pp_rank,
pp_size=self.pp_size,
attn_cp_rank=self.attn_cp_rank,
attn_cp_size=self.attn_cp_size,
attn_cp_rank=attn_cp_rank,
attn_cp_size=attn_cp_size,
is_mla_model=is_mla_backend,
enable_storage_metrics=self.enable_storage_metrics,
is_page_first_layout=self.mem_pool_host.layout == "page_first",
@@ -963,16 +984,13 @@ def prefetch_thread_func(self):
if operation is None:
continue
hash_value, storage_hit_count = self._storage_hit_query(operation)
if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor(
storage_hit_count, dtype=torch.int
)
torch.distributed.all_reduce(
storage_hit_count_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.prefetch_tp_group,
)
storage_hit_count = storage_hit_count_tensor.item()
storage_hit_count_tensor = torch.tensor(
storage_hit_count, dtype=torch.int
)
self._all_reduce_prefetch_groups(
storage_hit_count_tensor, torch.distributed.ReduceOp.MIN
)
storage_hit_count = storage_hit_count_tensor.item()

if storage_hit_count < self.prefetch_threshold:
# not to prefetch if not enough benefits
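With this refactor the hit-count sync no longer branches on tp_world_size; _all_reduce_prefetch_groups is simply a no-op when no sync groups were created. The MIN reduction is what makes the prefetch decision deterministic: every participant must take the same branch on the threshold, and the only count all of them can safely act on is the smallest any rank observed. A toy single-process illustration of the effect (all values made up):

    # min() plays the role of all_reduce with ReduceOp.MIN across the sync groups.
    per_rank_hit_counts = [96, 128, 64, 128]  # ranks can disagree, e.g. after partial writes
    prefetch_threshold = 80

    agreed_hit_count = min(per_rank_hit_counts)

    # Every rank decides on the reduced value, so all of them skip this
    # prefetch (64 < 80), even though three ranks locally saw enough hits.
    assert agreed_hit_count < prefetch_threshold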
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/scheduler.py
@@ -803,14 +803,14 @@ def init_cache_with_memory_pool(self):
if self.server_args.enable_dp_attention
else self.tp_cpu_group
),
attn_cp_cache_group=self.attn_cp_cpu_group,
attn_tp_cache_group=self.attn_tp_cpu_group,
eviction_policy=server_args.radix_eviction_policy,
enable_metrics=self.enable_metrics,
enable_kv_cache_events=self.enable_kv_cache_events,
enable_mamba_extra_buffer=server_args.enable_mamba_extra_buffer(),
pp_rank=self.pp_rank,
pp_size=self.pp_size,
attn_cp_rank=self.attn_cp_rank,
attn_cp_size=self.attn_cp_size,
chunked_prefill_size=effective_chunked_prefill_size,
sliding_window_size=self.sliding_window_size,
)
2 changes: 2 additions & 0 deletions python/sglang/srt/mem_cache/cache_init_params.py
@@ -20,6 +20,8 @@ class CacheInitParams:

is_eagle: bool = False
tp_cache_group: Optional[torch.distributed.ProcessGroup] = None
attn_cp_cache_group: Optional[torch.distributed.ProcessGroup] = None
attn_tp_cache_group: Optional[torch.distributed.ProcessGroup] = None
eviction_policy: str = "lru"
disable_finished_insert: bool = False

2 changes: 2 additions & 0 deletions python/sglang/srt/mem_cache/hi_mamba_radix_cache.py
@@ -143,6 +143,8 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
prefetch_threshold=prefetch_threshold,
load_cache_event=self.load_cache_event,
enable_storage_metrics=self.enable_storage_metrics,
attn_cp_group=params.attn_cp_cache_group,
attn_tp_group=params.attn_tp_cache_group,
)
self._apply_storage_runtime_config(
storage_backend=server_args.hicache_storage_backend,
88 changes: 46 additions & 42 deletions python/sglang/srt/mem_cache/hiradix_cache.py
@@ -98,11 +98,11 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
)

self.tp_group = params.tp_cache_group
self.attn_cp_group = params.attn_cp_cache_group
self.attn_tp_group = params.attn_tp_cache_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.pp_rank = params.pp_rank
self.pp_size = params.pp_size
self.attn_cp_rank = params.attn_cp_rank
self.attn_cp_size = params.attn_cp_size
self.enable_storage = server_args.hicache_storage_backend is not None
self.enable_storage_metrics = self.enable_storage and params.enable_metrics
self.extra_metric_labels = server_args.extra_metric_labels
@@ -130,6 +130,8 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
prefetch_threshold=prefetch_threshold,
enable_storage_metrics=self.enable_storage_metrics,
load_cache_event=self.load_cache_event,
attn_cp_group=self.attn_cp_group,
attn_tp_group=self.attn_tp_group,
)
else:
self.cache_controller = HiCacheController(
Expand All @@ -138,6 +140,8 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
self.page_size,
self.tp_group,
load_cache_event=self.load_cache_event,
attn_cp_group=self.attn_cp_group,
attn_tp_group=self.attn_tp_group,
write_policy=server_args.hicache_write_policy,
io_backend=server_args.hicache_io_backend,
storage_backend=server_args.hicache_storage_backend,
Expand All @@ -146,8 +150,6 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):
storage_backend_extra_config=extra_config,
pp_rank=self.pp_rank,
pp_size=self.pp_size,
attn_cp_rank=self.attn_cp_rank,
attn_cp_size=self.attn_cp_size,
enable_storage_metrics=self.enable_storage_metrics,
)
self._apply_storage_runtime_config(
@@ -184,6 +186,24 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs):

super().__init__(params=params)

def _all_reduce_attn_groups(self, tensor: torch.Tensor, op):
reduced = False
for group in (self.attn_cp_group, self.attn_tp_group):
if group is not None and torch.distributed.get_world_size(group=group) > 1:
torch.distributed.all_reduce(tensor, op=op, group=group)
reduced = True
if not reduced and self.tp_world_size > 1:
torch.distributed.all_reduce(tensor, op=op, group=self.tp_group)

def _barrier_attn_groups(self):
waited = False
for group in (self.attn_cp_group, self.attn_tp_group):
if group is not None and torch.distributed.get_world_size(group=group) > 1:
torch.distributed.barrier(group=group)
waited = True
if not waited and self.tp_world_size > 1:
torch.distributed.barrier(group=self.tp_group)

def shutdown(self):
"""Best-effort auto-detach of storage backend on process shutdown.

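_all_reduce_attn_groups and _barrier_attn_groups encode the compatibility story of this hunk: reduce over the attention CP and TP groups when they were provided, and fall back to the legacy full tp_group only when neither exists, so pre-CP deployments keep their old synchronization behavior. A minimal sketch of that control flow, with rank tuples standing in for process groups and a recording stub in place of torch.distributed.all_reduce (all names here are hypothetical):

    def all_reduce_attn_groups(tensor, attn_groups, tp_group, all_reduce):
        reduced = False
        for group in attn_groups:
            if group is not None and len(group) > 1:
                all_reduce(tensor, group)
                reduced = True
        if not reduced and len(tp_group) > 1:
            all_reduce(tensor, tp_group)  # legacy single-group path

    calls = []
    record = lambda tensor, group: calls.append(group)

    # CP spans ranks (0, 2) and TP spans (0, 1): both reduced, no fallback.
    all_reduce_attn_groups(0, [(0, 2), (0, 1)], (0, 1, 2, 3), record)
    assert calls == [(0, 2), (0, 1)]

    # No attn groups (pre-CP setup): behavior matches the old tp_group reduce.
    calls.clear()
    all_reduce_attn_groups(0, [None, None], (0, 1, 2, 3), record)
    assert calls == [(0, 1, 2, 3)]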
@@ -220,14 +240,17 @@ def _apply_storage_runtime_config(
self.enable_storage_metrics = enable_storage_metrics

if self.enable_storage_metrics:
attn_cp_rank, attn_cp_size = (
self.cache_controller.get_attn_cp_rank_and_size()
)
labels = {
"storage_backend": storage_backend,
"tp_rank": self.cache_controller.tp_rank,
"dp_rank": self.cache_controller.dp_rank,
"pp_rank": self.cache_controller.pp_rank,
"pp_size": self.cache_controller.pp_size,
"attn_cp_rank": self.cache_controller.attn_cp_rank,
"attn_cp_size": self.cache_controller.attn_cp_size,
"attn_cp_rank": attn_cp_rank,
"attn_cp_size": attn_cp_size,
}
if extra_metric_labels:
labels.update(extra_metric_labels)
@@ -741,13 +764,8 @@ def writing_check(self, write_back=False):
break
finish_count += 1
queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to radix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
# Keep cache state transitions identical across CPxTP participants.
self._all_reduce_attn_groups(queue_size, torch.distributed.ReduceOp.MIN)

finish_count = int(queue_size.item())
while finish_count > 0:
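The MIN over finish_count is what keeps the host radix tree identical everywhere: at the moment of the check, ranks may have observed different numbers of completed writes, so each commits only the agreed minimum and the surplus is picked up by a later writing_check once every rank has seen those writes finish. A toy illustration with made-up counts:

    per_rank_finish_counts = [3, 2, 3]        # completed writes seen per rank
    commit_now = min(per_rank_finish_counts)  # effect of ReduceOp.MIN
    assert commit_now == 2

    # The surplus is not lost; those operations are simply committed on a
    # subsequent check, keeping every rank's tree updates in lockstep.
    assert [c - commit_now for c in per_rank_finish_counts] == [1, 0, 1]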
@@ -1074,10 +1092,7 @@ def drain_storage_control_queues(self):
],
dtype=torch.int,
)
if self.tp_world_size > 1:
torch.distributed.all_reduce(
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
)
self._all_reduce_attn_groups(qsizes, torch.distributed.ReduceOp.MIN)

n_revoke, n_backup, n_release = map(int, qsizes.tolist())
self._drain_storage_control_queues_impl(
@@ -1118,18 +1133,13 @@ def can_terminate_prefetch(self, operation: PrefetchOperation):
return True

operation_terminated = operation.is_terminated()
if self.tp_world_size > 1:
states = torch.tensor(
[1 - int(can_terminate), int(operation_terminated)],
dtype=torch.int,
)
torch.distributed.all_reduce(
states,
op=torch.distributed.ReduceOp.MAX,
group=self.tp_group,
)
can_terminate = states[0].item() == 0
operation_terminated = states[1].item() == 1
states = torch.tensor(
[1 - int(can_terminate), int(operation_terminated)],
dtype=torch.int,
)
self._all_reduce_attn_groups(states, torch.distributed.ReduceOp.MAX)
can_terminate = states[0].item() == 0
operation_terminated = states[1].item() == 1
# the operation should be terminated if it is already terminated on any TP worker
# or it meets the termination condition on all TP workers
can_terminate = can_terminate or operation_terminated
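The two-element state vector is a compact trick: one MAX all-reduce computes an AND and an OR across ranks at the same time. MAX over (1 - can_terminate) is 0 only if every rank voted yes (AND), while MAX over operation_terminated is 1 if any rank already terminated (OR). A standalone sketch of the vote, with made-up per-rank flags:

    def termination_vote(per_rank_can_terminate, per_rank_terminated):
        # Elementwise MAX across ranks, as the all_reduce would compute it.
        state0 = max(1 - int(c) for c in per_rank_can_terminate)
        state1 = max(int(t) for t in per_rank_terminated)
        all_can_terminate = state0 == 0   # AND across ranks
        any_terminated = state1 == 1      # OR across ranks
        return all_can_terminate or any_terminated

    assert termination_vote([True, True], [False, False]) is True    # unanimous
    assert termination_vote([True, False], [False, False]) is False  # one rank busy
    assert termination_vote([True, False], [False, True]) is True    # a peer already bailed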
@@ -1159,17 +1169,12 @@ def check_prefetch_progress(self, req_id: str) -> bool:
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

min_completed_tokens = completed_tokens
if self.tp_world_size > 1:
# synchronize TP workers to make the same update to hiradix cache
completed_tokens_tensor = torch.tensor(
min_completed_tokens, dtype=torch.int
)
torch.distributed.all_reduce(
completed_tokens_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
min_completed_tokens = completed_tokens_tensor.item()
# Synchronize workers before mutating host cache tree state.
completed_tokens_tensor = torch.tensor(min_completed_tokens, dtype=torch.int)
self._all_reduce_attn_groups(
completed_tokens_tensor, torch.distributed.ReduceOp.MIN
)
min_completed_tokens = completed_tokens_tensor.item()
fetched_token_ids = token_ids[:min_completed_tokens]
written_indices = host_indices[:min_completed_tokens]
matched_length = self._insert_helper_host(
@@ -1494,8 +1499,7 @@ def release_aborted_request(self, rid: str):
return

completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
if self.tp_world_size > 1:
torch.distributed.barrier(group=self.tp_group)
self._barrier_attn_groups()
last_host_node.release_host()
del self.ongoing_prefetch[rid]
self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])