Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from sglang.srt.environ import envs
from sglang.srt.managers.prefill_delayer import PrefillDelayerSinglePassExecutor
from sglang.srt.mem_cache.base_prefix_cache import DecLockRefParams
from sglang.srt.utils import get_bool_env_var

_ROUTING_KEY_POLICY_DEBUG_LOG = get_bool_env_var("SGLANG_ROUTING_KEY_POLICY_DEBUG_LOG")
Expand Down Expand Up @@ -701,16 +700,18 @@ def add_chunked_req(self, req: Req):

@contextmanager
def _lock_node(self, last_node: TreeNode):
dec_lock_params = None
try:
result = self.tree_cache.inc_lock_ref(last_node)
if self.tree_cache.supports_swa() and self.tree_cache.is_tree_cache():
swa_uuid_for_lock = result.swa_uuid_for_lock
if self.tree_cache.is_tree_cache():
# init_load_back may revive SWA/Mamba tombstones while this
# temporary admission lock is held. Release must mirror the
# exact nodes skipped at acquire time.
dec_lock_params = result.to_dec_params()
Comment thread
ispobock marked this conversation as resolved.
yield None
finally:
if self.tree_cache.supports_swa() and self.tree_cache.is_tree_cache():
self.tree_cache.dec_lock_ref(
last_node, DecLockRefParams(swa_uuid_for_lock=swa_uuid_for_lock)
)
if dec_lock_params is not None:
self.tree_cache.dec_lock_ref(last_node, dec_lock_params)
else:
self.tree_cache.dec_lock_ref(last_node)

Expand Down
20 changes: 19 additions & 1 deletion python/sglang/srt/mem_cache/base_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.unified_cache_components.tree_component import (
ComponentType,
)


@runtime_checkable
Expand Down Expand Up @@ -94,17 +97,32 @@ class IncLockRefResult:

delta: Optional[int] = None
swa_uuid_for_lock: Optional[int] = None
# Component nodes that were tombstones at acquire time. Replaying this set
# at release prevents a short-lived lock from consuming a later load-back or
# request lock after that tombstone becomes a valid device value.
skip_lock_node_ids: dict[ComponentType, set[int]] = dataclasses.field(
default_factory=dict
)

def to_dec_params(self) -> "DecLockRefParams":
"""Convert to the corresponding DecLockRefParams for dec_lock_ref."""
return DecLockRefParams(swa_uuid_for_lock=self.swa_uuid_for_lock)
return DecLockRefParams(
swa_uuid_for_lock=self.swa_uuid_for_lock,
skip_lock_node_ids={
component_type: set(node_ids)
for component_type, node_ids in self.skip_lock_node_ids.items()
},
)


@dataclasses.dataclass
class DecLockRefParams:
"""Parameters for dec_lock_ref operation."""

swa_uuid_for_lock: Optional[int] = None
skip_lock_node_ids: dict[ComponentType, set[int]] = dataclasses.field(
default_factory=dict
)


@dataclasses.dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def acquire_component_lock(
cd.lock_ref += 1
self.cache.evictable_device_leaves.discard(cur)
cur = cur.parent
result = IncLockRefResult(
delta=delta, swa_uuid_for_lock=result.swa_uuid_for_lock
)
result.delta = delta
return result

def release_component_lock(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,27 @@ def acquire_component_lock(
ct = self.component_type
cd = node.component_data[ct]
value = cd.value
if value is not None:
if cd.lock_ref == 0:
vlen = len(value)
self.cache.component_evictable_size_[ct] -= vlen
self.cache.component_protected_size_[ct] += vlen
cd.lock_ref += 1
# A node in skip_lock_node_ids was a tombstone when this lock was acquired.
if value is None:
result.skip_lock_node_ids.setdefault(ct, set()).add(node.id)
return result

if cd.lock_ref == 0:
vlen = len(value)
self.cache.component_evictable_size_[ct] -= vlen
self.cache.component_protected_size_[ct] += vlen
cd.lock_ref += 1
return result

def release_component_lock(
self, node: UnifiedTreeNode, params: Optional[DecLockRefParams]
) -> None:
ct = self.component_type
cd = node.component_data[ct]
skip_lock_node_ids = params.skip_lock_node_ids.get(ct, ()) if params else ()
if node.id in skip_lock_node_ids:
return

value = cd.value
if value is not None and cd.lock_ref > 0:
if cd.lock_ref == 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def acquire_component_lock(
while cur != root and swa_lock_size < sliding_window_size:
comp = cur.component_data[ct]
if comp.value is None:
result.skip_lock_node_ids.setdefault(ct, set()).add(cur.id)
cur = cur.parent
continue
if comp.lock_ref == 0:
Expand All @@ -385,14 +386,16 @@ def release_component_lock(
ct = self.component_type
root = self.cache.root_node
swa_uuid_for_lock = params.swa_uuid_for_lock if params else None
skip_lock_node_ids = params.skip_lock_node_ids.get(ct, ()) if params else ()
dec_swa = True

# lock_ref == 0 means acquire_component_lock skipped this node
# (tombstone at acquire time) or load_back revived a tombstone between
# acquire and release. Either way, there is nothing for us to undo here.
# A node in skip_lock_node_ids was a tombstone when this lock was acquired.
cur = node
while cur != root and dec_swa:
comp = cur.component_data[ct]
if cur.id in skip_lock_node_ids:
cur = cur.parent
continue
if comp.lock_ref == 0:
cur = cur.parent
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1808,6 +1808,99 @@ def test_hicache_swa_commit_load_back_rebuilds_mapping(self):
n_swa,
)

def test_hicache_swa_temp_lock_does_not_release_restored_tombstone(self):
"""A temporary scheduler lock that skipped a SWA tombstone must not
release later load-back/request locks after the tombstone is restored.
"""
if not self.cfg.has_swa:
self.skipTest("requires SWA")
if self.cfg.has_mamba:
self.skipTest("SWA-only path keeps the chain construction simple")

tree, allocator, _, chain, _ = self._swa_finalize_setup()
leaf = chain[-1]
tombstone = leaf
cd = tombstone.component_data[ComponentType.SWA]
old_swa = cd.value
self.assertIsNotNone(old_swa)

cd.value = None
tree.lru_lists[ComponentType.SWA].remove_node(tombstone)
tree.host_lru_lists[ComponentType.SWA].insert_mru(tombstone)
tree.component_evictable_size_[ComponentType.SWA] -= len(old_swa)

temp_lock = tree.inc_lock_ref(leaf)
self.assertEqual(cd.lock_ref, 0)

xfer = tree.components[ComponentType.SWA].build_hicache_transfers(
leaf, CacheTransferPhase.LOAD_BACK
)[0]
new_swa = allocator.swa_attn_allocator.alloc(int(xfer.host_indices.numel()))
self.assertIsNotNone(new_swa)
xfer.device_indices = new_swa
tree.components[ComponentType.SWA].commit_hicache_transfer(
leaf, CacheTransferPhase.LOAD_BACK, transfers=[xfer]
)

load_back_lock = tree.inc_lock_ref(leaf)
request_lock = tree.inc_lock_ref(leaf)
self.assertEqual(cd.lock_ref, 2)

tree.dec_lock_ref(leaf, temp_lock.to_dec_params())
self.assertEqual(cd.lock_ref, 2)

tree.dec_lock_ref(leaf, load_back_lock.to_dec_params())
tree.dec_lock_ref(leaf, request_lock.to_dec_params())
self.assertEqual(cd.lock_ref, 0)

def test_hicache_mamba_temp_lock_does_not_release_restored_tombstone(self):
"""A temporary scheduler lock that skipped a Mamba tombstone must not
release later load-back/request locks after the tombstone is restored.
"""
if not self.cfg.has_mamba:
self.skipTest("requires Mamba component")
if self.cfg.has_swa:
self.skipTest("Mamba-only path keeps the chain construction simple")

tree, allocator, req_to_token_pool = build_fixture(self.cfg)
seq = self._make_seq(1, 2)
self._insert(tree, allocator, req_to_token_pool, seq)
m = tree.match_prefix(MatchPrefixParams(key=RadixKey(seq)))
node = m.last_device_node
cd = node.component_data[ComponentType.MAMBA]
old_mamba = cd.value
self.assertIsNotNone(old_mamba)
self._simulate_backup(tree, node)

cd.value = None
tree.lru_lists[ComponentType.MAMBA].remove_node(node)
tree.host_lru_lists[ComponentType.MAMBA].insert_mru(node)
tree.component_evictable_size_[ComponentType.MAMBA] -= len(old_mamba)

temp_lock = tree.inc_lock_ref(node)
self.assertEqual(cd.lock_ref, 0)

xfer = tree.components[ComponentType.MAMBA].build_hicache_transfers(
node, CacheTransferPhase.LOAD_BACK
)[0]
new_mamba = req_to_token_pool.mamba_pool.alloc(1)
self.assertIsNotNone(new_mamba)
xfer.device_indices = new_mamba
tree.components[ComponentType.MAMBA].commit_hicache_transfer(
node, CacheTransferPhase.LOAD_BACK, transfers=[xfer]
)

load_back_lock = tree.inc_lock_ref(node)
request_lock = tree.inc_lock_ref(node)
self.assertEqual(cd.lock_ref, 2)

tree.dec_lock_ref(node, temp_lock.to_dec_params())
self.assertEqual(cd.lock_ref, 2)

tree.dec_lock_ref(node, load_back_lock.to_dec_params())
tree.dec_lock_ref(node, request_lock.to_dec_params())
self.assertEqual(cd.lock_ref, 0)

def test_hicache_mixed_backup_evict_insert(self):
"""Complex scenario: backup some, evict, insert new, verify invariants."""
if self._skip_unsupported_hicache_test():
Expand Down
Loading