diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 257d30293009..29b90038ad26 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -4,7 +4,6 @@ from sglang.srt.environ import envs from sglang.srt.managers.prefill_delayer import PrefillDelayerSinglePassExecutor -from sglang.srt.mem_cache.base_prefix_cache import DecLockRefParams from sglang.srt.utils import get_bool_env_var _ROUTING_KEY_POLICY_DEBUG_LOG = get_bool_env_var("SGLANG_ROUTING_KEY_POLICY_DEBUG_LOG") @@ -701,16 +700,18 @@ def add_chunked_req(self, req: Req): @contextmanager def _lock_node(self, last_node: TreeNode): + dec_lock_params = None try: result = self.tree_cache.inc_lock_ref(last_node) - if self.tree_cache.supports_swa() and self.tree_cache.is_tree_cache(): - swa_uuid_for_lock = result.swa_uuid_for_lock + if self.tree_cache.is_tree_cache(): + # init_load_back may revive SWA/Mamba tombstones while this + # temporary admission lock is held. Release must mirror the + # exact nodes skipped at acquire time. + dec_lock_params = result.to_dec_params() yield None finally: - if self.tree_cache.supports_swa() and self.tree_cache.is_tree_cache(): - self.tree_cache.dec_lock_ref( - last_node, DecLockRefParams(swa_uuid_for_lock=swa_uuid_for_lock) - ) + if dec_lock_params is not None: + self.tree_cache.dec_lock_ref(last_node, dec_lock_params) else: self.tree_cache.dec_lock_ref(last_node) diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 8ec5b15b7c5b..ae9df719e3e4 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -22,6 +22,9 @@ if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req from sglang.srt.mem_cache.radix_cache import RadixKey + from sglang.srt.mem_cache.unified_cache_components.tree_component import ( + ComponentType, + ) @runtime_checkable @@ -94,10 +97,22 @@ class IncLockRefResult: delta: Optional[int] = None swa_uuid_for_lock: Optional[int] = None + # Component nodes that were tombstones at acquire time. Replaying this set + # at release prevents a short-lived lock from consuming a later load-back or + # request lock after that tombstone becomes a valid device value. + skip_lock_node_ids: dict[ComponentType, set[int]] = dataclasses.field( + default_factory=dict + ) def to_dec_params(self) -> "DecLockRefParams": """Convert to the corresponding DecLockRefParams for dec_lock_ref.""" - return DecLockRefParams(swa_uuid_for_lock=self.swa_uuid_for_lock) + return DecLockRefParams( + swa_uuid_for_lock=self.swa_uuid_for_lock, + skip_lock_node_ids={ + component_type: set(node_ids) + for component_type, node_ids in self.skip_lock_node_ids.items() + }, + ) @dataclasses.dataclass @@ -105,6 +120,9 @@ class DecLockRefParams: """Parameters for dec_lock_ref operation.""" swa_uuid_for_lock: Optional[int] = None + skip_lock_node_ids: dict[ComponentType, set[int]] = dataclasses.field( + default_factory=dict + ) @dataclasses.dataclass diff --git a/python/sglang/srt/mem_cache/unified_cache_components/full_component.py b/python/sglang/srt/mem_cache/unified_cache_components/full_component.py index c9470f66fbe8..805009dbbc84 100644 --- a/python/sglang/srt/mem_cache/unified_cache_components/full_component.py +++ b/python/sglang/srt/mem_cache/unified_cache_components/full_component.py @@ -166,9 +166,7 @@ def acquire_component_lock( cd.lock_ref += 1 self.cache.evictable_device_leaves.discard(cur) cur = cur.parent - result = IncLockRefResult( - delta=delta, swa_uuid_for_lock=result.swa_uuid_for_lock - ) + result.delta = delta return result def release_component_lock( diff --git a/python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py b/python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py index 59dfc476840d..45397c656479 100644 --- a/python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py +++ b/python/sglang/srt/mem_cache/unified_cache_components/mamba_component.py @@ -217,12 +217,16 @@ def acquire_component_lock( ct = self.component_type cd = node.component_data[ct] value = cd.value - if value is not None: - if cd.lock_ref == 0: - vlen = len(value) - self.cache.component_evictable_size_[ct] -= vlen - self.cache.component_protected_size_[ct] += vlen - cd.lock_ref += 1 + # A node in skip_lock_node_ids was a tombstone when this lock was acquired. + if value is None: + result.skip_lock_node_ids.setdefault(ct, set()).add(node.id) + return result + + if cd.lock_ref == 0: + vlen = len(value) + self.cache.component_evictable_size_[ct] -= vlen + self.cache.component_protected_size_[ct] += vlen + cd.lock_ref += 1 return result def release_component_lock( @@ -230,6 +234,10 @@ def release_component_lock( ) -> None: ct = self.component_type cd = node.component_data[ct] + skip_lock_node_ids = params.skip_lock_node_ids.get(ct, ()) if params else () + if node.id in skip_lock_node_ids: + return + value = cd.value if value is not None and cd.lock_ref > 0: if cd.lock_ref == 1: diff --git a/python/sglang/srt/mem_cache/unified_cache_components/swa_component.py b/python/sglang/srt/mem_cache/unified_cache_components/swa_component.py index ec78978a118d..306ad2996984 100644 --- a/python/sglang/srt/mem_cache/unified_cache_components/swa_component.py +++ b/python/sglang/srt/mem_cache/unified_cache_components/swa_component.py @@ -362,6 +362,7 @@ def acquire_component_lock( while cur != root and swa_lock_size < sliding_window_size: comp = cur.component_data[ct] if comp.value is None: + result.skip_lock_node_ids.setdefault(ct, set()).add(cur.id) cur = cur.parent continue if comp.lock_ref == 0: @@ -385,14 +386,16 @@ def release_component_lock( ct = self.component_type root = self.cache.root_node swa_uuid_for_lock = params.swa_uuid_for_lock if params else None + skip_lock_node_ids = params.skip_lock_node_ids.get(ct, ()) if params else () dec_swa = True - # lock_ref == 0 means acquire_component_lock skipped this node - # (tombstone at acquire time) or load_back revived a tombstone between - # acquire and release. Either way, there is nothing for us to undo here. + # A node in skip_lock_node_ids was a tombstone when this lock was acquired. cur = node while cur != root and dec_swa: comp = cur.component_data[ct] + if cur.id in skip_lock_node_ids: + cur = cur.parent + continue if comp.lock_ref == 0: cur = cur.parent continue diff --git a/test/registered/unit/mem_cache/test_unified_radix_cache_unittest.py b/test/registered/unit/mem_cache/test_unified_radix_cache_unittest.py index 10cf822f0a69..878c3b60c6cf 100644 --- a/test/registered/unit/mem_cache/test_unified_radix_cache_unittest.py +++ b/test/registered/unit/mem_cache/test_unified_radix_cache_unittest.py @@ -1808,6 +1808,99 @@ def test_hicache_swa_commit_load_back_rebuilds_mapping(self): n_swa, ) + def test_hicache_swa_temp_lock_does_not_release_restored_tombstone(self): + """A temporary scheduler lock that skipped a SWA tombstone must not + release later load-back/request locks after the tombstone is restored. + """ + if not self.cfg.has_swa: + self.skipTest("requires SWA") + if self.cfg.has_mamba: + self.skipTest("SWA-only path keeps the chain construction simple") + + tree, allocator, _, chain, _ = self._swa_finalize_setup() + leaf = chain[-1] + tombstone = leaf + cd = tombstone.component_data[ComponentType.SWA] + old_swa = cd.value + self.assertIsNotNone(old_swa) + + cd.value = None + tree.lru_lists[ComponentType.SWA].remove_node(tombstone) + tree.host_lru_lists[ComponentType.SWA].insert_mru(tombstone) + tree.component_evictable_size_[ComponentType.SWA] -= len(old_swa) + + temp_lock = tree.inc_lock_ref(leaf) + self.assertEqual(cd.lock_ref, 0) + + xfer = tree.components[ComponentType.SWA].build_hicache_transfers( + leaf, CacheTransferPhase.LOAD_BACK + )[0] + new_swa = allocator.swa_attn_allocator.alloc(int(xfer.host_indices.numel())) + self.assertIsNotNone(new_swa) + xfer.device_indices = new_swa + tree.components[ComponentType.SWA].commit_hicache_transfer( + leaf, CacheTransferPhase.LOAD_BACK, transfers=[xfer] + ) + + load_back_lock = tree.inc_lock_ref(leaf) + request_lock = tree.inc_lock_ref(leaf) + self.assertEqual(cd.lock_ref, 2) + + tree.dec_lock_ref(leaf, temp_lock.to_dec_params()) + self.assertEqual(cd.lock_ref, 2) + + tree.dec_lock_ref(leaf, load_back_lock.to_dec_params()) + tree.dec_lock_ref(leaf, request_lock.to_dec_params()) + self.assertEqual(cd.lock_ref, 0) + + def test_hicache_mamba_temp_lock_does_not_release_restored_tombstone(self): + """A temporary scheduler lock that skipped a Mamba tombstone must not + release later load-back/request locks after the tombstone is restored. + """ + if not self.cfg.has_mamba: + self.skipTest("requires Mamba component") + if self.cfg.has_swa: + self.skipTest("Mamba-only path keeps the chain construction simple") + + tree, allocator, req_to_token_pool = build_fixture(self.cfg) + seq = self._make_seq(1, 2) + self._insert(tree, allocator, req_to_token_pool, seq) + m = tree.match_prefix(MatchPrefixParams(key=RadixKey(seq))) + node = m.last_device_node + cd = node.component_data[ComponentType.MAMBA] + old_mamba = cd.value + self.assertIsNotNone(old_mamba) + self._simulate_backup(tree, node) + + cd.value = None + tree.lru_lists[ComponentType.MAMBA].remove_node(node) + tree.host_lru_lists[ComponentType.MAMBA].insert_mru(node) + tree.component_evictable_size_[ComponentType.MAMBA] -= len(old_mamba) + + temp_lock = tree.inc_lock_ref(node) + self.assertEqual(cd.lock_ref, 0) + + xfer = tree.components[ComponentType.MAMBA].build_hicache_transfers( + node, CacheTransferPhase.LOAD_BACK + )[0] + new_mamba = req_to_token_pool.mamba_pool.alloc(1) + self.assertIsNotNone(new_mamba) + xfer.device_indices = new_mamba + tree.components[ComponentType.MAMBA].commit_hicache_transfer( + node, CacheTransferPhase.LOAD_BACK, transfers=[xfer] + ) + + load_back_lock = tree.inc_lock_ref(node) + request_lock = tree.inc_lock_ref(node) + self.assertEqual(cd.lock_ref, 2) + + tree.dec_lock_ref(node, temp_lock.to_dec_params()) + self.assertEqual(cd.lock_ref, 2) + + tree.dec_lock_ref(node, load_back_lock.to_dec_params()) + tree.dec_lock_ref(node, request_lock.to_dec_params()) + self.assertEqual(cd.lock_ref, 0) + def test_hicache_mixed_backup_evict_insert(self): """Complex scenario: backup some, evict, insert new, verify invariants.""" if self._skip_unsupported_hicache_test():