diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 65f2e4cb8c9c..9f9046a2d050 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -152,9 +152,8 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: len(reusing) <= 1 ), "only one chunked request may reuse req_pool_idx in a batch" assert all( - reqs[i].inflight_middle_chunks > 0 or reqs[i].kv_committed_len > 0 - for i in reusing - ), "reusing request must be chunked or have committed KV" + reqs[i].kv_committed_len > 0 for i in reusing + ), "reusing request must have committed KV" need_size = len(reqs) - len(reusing) if need_size > len(self.free_slots): @@ -1655,11 +1654,16 @@ def get_next_disagg_decode_batch_to_run( # Process pending prebuilt batch: output processing + filter + merge new_prebuilt_batch = self.get_new_prebuilt_batch() if new_prebuilt_batch: - assert self.chunked_req is None + # C10: dead assert removed — post-C4 chunked-resume not in waiting_queue. self.batch_result_processor.process_batch_result_prebuilt( new_prebuilt_batch ) - new_prebuilt_batch.filter_batch() + # Defensive: chunked prefill is a prefill-side concept; decode-side + # prebuilt batches shouldn't carry has_pending_chunk reqs. The + # waiting_queue invariant is checked by _assert_invariants in sync + # mode; this flag protects against any future code that would route + # a chunked req through the disagg decode path. 
+ new_prebuilt_batch.filter_batch(exclude_chunked_req=True) if not new_prebuilt_batch.is_empty(): if self.running_batch.is_empty(): self.running_batch = new_prebuilt_batch diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index bb07f4012a5c..b4272ccebb57 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -514,7 +514,7 @@ def process_batch_result_disagg_prefill( for i, (req, next_token_id) in enumerate( zip(batch.reqs, next_token_ids, strict=True) ): - if req.inflight_middle_chunks <= 0: + if req.pending_middle_outputs <= 0: req.time_stats.set_prefill_finished_time() # There is no output_ids for prefill @@ -556,6 +556,8 @@ def process_batch_result_disagg_prefill( # This can happen if the grammar is not set correctly or the token is invalid. error_message = f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}" release_kv_cache(req, self.tree_cache) + # audit D-prefill-1: disagg PREFILL release path + self._deactivate(req) prepare_abort( req, error_message, @@ -564,7 +566,7 @@ def process_batch_result_disagg_prefill( req.grammar.finished = req.finished() else: # being chunked reqs' prefill is not finished - req.inflight_middle_chunks -= 1 + req.pending_middle_outputs -= 1 if req.return_logprob: extend_logprob_start_len = extend_logprob_start_len_per_req[i] @@ -640,6 +642,8 @@ def process_disagg_prefill_inflight_queue( undone_reqs.append(req) elif poll == KVPoll.Success: # transfer done release_kv_cache(req, self.tree_cache) # unlock the tree + # audit D-prefill-2: disagg PREFILL release path + self._deactivate(req) req.finished_reason = FINISH_LENGTH(length=0) # FIXME: clean up req's data in transfer engine if hasattr(req.disagg_kv_sender, "clear"): @@ -655,6 +659,8 @@ def process_disagg_prefill_inflight_queue( logger.warning(error_message) req.time_stats.trace_ctx.abort(abort_info={"reason": error_message}) release_kv_cache(req, 
self.tree_cache) # unlock the tree + # audit D-prefill-3: disagg PREFILL release path + self._deactivate(req) prepare_abort( req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR ) @@ -725,30 +731,32 @@ def get_transferred_rids(self: Scheduler) -> List[str]: return transferred_rids def process_prefill_chunk(self: Scheduler) -> None: - chunked_req_to_exclude = set() - if self.chunked_req: - chunked_req_to_exclude.add(self.chunked_req) - maybe_cache_unfinished_req(self.chunked_req, self.tree_cache, chunked=True) - if self.enable_overlap: - # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved - self.chunked_req.tmp_end_idx = min( - len(self.chunked_req.fill_ids), - len(self.chunked_req.origin_input_ids), - ) - else: - self.send_kv_chunk(self.chunked_req) - self.running_batch.batch_is_full = False + # audit C10: disagg PREFILL chunked-resume now lives in active_reqs + # (same as sync mode post-C4); iterate chunked_reqs() view. + for req in self.chunked_reqs(): + if not req.is_dllm(): + maybe_cache_unfinished_req(req, self.tree_cache, chunked=True) + if self.enable_overlap: + # Delay KV transfer to process_batch_result_disagg_prefill + # when overlap is enabled to ensure results are resolved. + req.tmp_end_idx = min( + len(req.fill_ids), + len(req.origin_input_ids), + ) + else: + self.send_kv_chunk(req) + self.running_batch.batch_is_full = False if self.last_batch and self.last_batch.forward_mode.is_extend(): - if self.last_batch.chunked_req: - # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. - # We need to discard it. 
- chunked_req_to_exclude.add(self.last_batch.chunked_req) - last_bs = self.last_batch.batch_size() - self.last_batch.filter_batch( - chunked_req_to_exclude=list(chunked_req_to_exclude) - ) + # Drop chunked-resume reqs from last_batch — running_batch runs + # decode forward and admitting a mid-prefill req there breaks + # shape + KV accounting. The dropped reqs stay in + # self.active_reqs and re-enter via the next iter's Stage A + # stash + admission cycle. + # Record the pre-filter size; the shrink check below compares against it. + last_bs = self.last_batch.batch_size() + self.last_batch.filter_batch(exclude_chunked_req=True) if self.last_batch.batch_size() < last_bs: self.running_batch.batch_is_full = False diff --git a/python/sglang/srt/dllm/mixin/scheduler.py b/python/sglang/srt/dllm/mixin/scheduler.py index 3fbff753118a..73bb7bfdbb08 100644 --- a/python/sglang/srt/dllm/mixin/scheduler.py +++ b/python/sglang/srt/dllm/mixin/scheduler.py @@ -201,7 +201,7 @@ def _update_state_for_batch( if can_run_list: self.dllm_manager.add_staging_reqs(can_run_list) - self.dllm_manager.increment_inflight_middle_chunks() + self.dllm_manager.increment_pending_middle_outputs() self.adder = adder self.can_run_list = can_run_list @@ -259,7 +259,6 @@ def process_dllm_incoming_reqs( req.init_next_round_input(self.tree_cache) res = adder.add_one_req( req, - has_chunked_req=True, truncation_align_size=self.truncation_align_size, ) @@ -339,10 +338,10 @@ def is_empty(self) -> bool: return True return len(self.waiting_queue) == 0 - def increment_inflight_middle_chunks(self) -> None: + def increment_pending_middle_outputs(self) -> None: """Increment chunked count for all staging requests.""" for req in self.staging_queue: - req.inflight_middle_chunks += 1 + req.pending_middle_outputs += 1 def filter_finished_reqs(self) -> None: """Remove finished requests from both queues.""" diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ce81a8cadded..dd059014038c 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ 
b/python/sglang/srt/managers/schedule_batch.py @@ -823,10 +823,22 @@ def __init__( # The prefix length that is inserted into the tree cache self.cache_protected_len: int = 0 - # Whether or not if it is chunked. It increments whenever - # it is chunked, and decrement whenever chunked request is - # processed. - self.inflight_middle_chunks = 0 + # Counter of middle-block prefill forwards that have been admitted + # but not yet output-processed for this req. Increments at admission + # for non-last chunks; decrements at output_processor. In PP, can + # exceed 1 because multiple microbatches may hold the same chunked + # req in flight concurrently. In non-PP, oscillates 0/1 within each + # iter. Used by output_processor to know whether this forward's + # sample is real (==0) or garbage (>0). + self.pending_middle_outputs = 0 + + # Persistent (cross-iter) flag set by admission when this req's + # current admission was truncated (more chunks remain). Cleared + # when last chunk is admitted (truncated=False) or on retract. + # Used by Stage A stash detection, filter_batch exclusion, and + # add_one_req's reuse-vs-fresh branch. Independent of pending_middle_outputs + # counter (transient) and kv_committed_len (derived). + self.has_pending_chunk = False # For retraction self.is_retracted = False @@ -1319,7 +1331,8 @@ def reset_for_retract(self): self.temp_input_top_logprobs_val = None self.temp_input_top_logprobs_idx = None self.extend_logprob_start_len = 0 - self.inflight_middle_chunks = 0 + self.pending_middle_outputs = 0 + self.has_pending_chunk = False self.mamba_pool_idx = None self.mamba_ping_pong_track_buffer = None self.mamba_next_track_idx = None @@ -1335,6 +1348,14 @@ def reset_for_retract(self): self.swa_evicted_seqlen = 0 self.extend_batch_idx = 0 self.decode_batch_idx = 0 + # Disagg-prefill send-side bookkeeping. The pre-v2 retract path never + # ran against a req that had started sending (retract only touched + # running_batch), so these stayed at init values. 
After v2 added + # pause(retract) coverage for active chunked-resume reqs, a retracted + # disagg-prefill req's stale start_send_idx would index garbage in the + # new row on re-prefill. + self.start_send_idx = 0 + self.tmp_end_idx = -1 # When using input_embeds, we cannot easily mix the original input embeddings # with the newly generated output token IDs during re-prefill of retracted request. @@ -1485,9 +1506,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # This is an optimization to reduce the overhead of the prefill check. batch_is_full: bool = False - # For chunked prefill in PP - chunked_req: Optional[Req] = None - # Sampling info sampling_info: SamplingBatchInfo = None @@ -1628,7 +1646,6 @@ def init_new( model_config: ModelConfig, enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, - chunked_req: Optional[Req] = None, dllm_config: Optional[DllmConfig] = None, ): return_logprob = any(req.return_logprob for req in reqs) @@ -1654,7 +1671,6 @@ def init_new( return_routed_experts=any(req.return_routed_experts for req in reqs), return_indexer_topk=any(req.return_indexer_topk for req in reqs), is_prefill_only=all(req.is_prefill_only for req in reqs), - chunked_req=chunked_req, dllm_config=dllm_config, ) return batch @@ -1931,6 +1947,13 @@ def prepare_for_extend(self): req._cache_breakdown_computed = True req.already_computed = seq_len + # Reset host_hit_length after init_load_back consumed it so that + # subsequent chunks' admissions skip init_load_back (host KV + # already loaded). Runs unconditionally: post-retract reqs have + # retracted_stain=True (skipping the outer block) but still + # match_prefix + init_load_back on their re-admission, so the + # reset must apply to them too. 
+ req.host_hit_length = 0 req.is_retracted = False if get_global_server_args().enable_mamba_extra_buffer(): @@ -2278,7 +2301,7 @@ def retract_all(self, server_args: ServerArgs): for idx in range(len(self.reqs)): self.release_req(idx, len(self.reqs) - idx, server_args) - self.filter_batch(retracted_reqs) + self.filter_batch(keep_indices=[]) return retracted_reqs def retract_decode( @@ -2486,21 +2509,27 @@ def prepare_for_decode(self): def filter_batch( self, - chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None, keep_indices: Optional[List[int]] = None, # FIXME(lsyin): deprecate this API after spec v1 is deprecated v1_spec_info_filtered: Optional[bool] = False, + exclude_chunked_req: bool = False, + exclude_in_flight_other_mb: Optional[set] = None, ): if keep_indices is None: - if isinstance(chunked_req_to_exclude, Req): - chunked_req_to_exclude = [chunked_req_to_exclude] - elif chunked_req_to_exclude is None: - chunked_req_to_exclude = [] + in_flight_rids = exclude_in_flight_other_mb or set() keep_indices = [ i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] not in chunked_req_to_exclude + and not ( + exclude_chunked_req + and ( + self.reqs[i].has_pending_chunk + or self.reqs[i].pending_middle_outputs > 0 + or self.reqs[i].is_dllm() + ) + ) + and self.reqs[i].rid not in in_flight_rids ] if keep_indices is None or len(keep_indices) == 0: @@ -2566,6 +2595,16 @@ def filter_batch( ) def merge_batch(self, other: "ScheduleBatch"): + # Caller must filter_batch(exclude_chunked_req=True) on the other batch + # before merging — running_batch runs decode forward and admitting a + # prefill-in-progress req there breaks shape + KV accounting. Mirror + # the full exclude_chunked_req predicate so PP middle-chunk and DLLM + # staging reqs are also caught here. 
+ assert not any( + r.has_pending_chunk or r.pending_middle_outputs > 0 or r.is_dllm() + for r in other.reqs + ) + # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 1ed7bd9ff437..7654c4c230f4 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -234,6 +234,7 @@ def _compute_prefix_matches( temporary_deprioritized: Set[int] = set() self.waiting_queue_radix_tree.reset() + # C10: chunked-resume no longer in waiting_queue (post-C4); revert to main-upstream sort. for r in waiting_queue: prefix_ids = r.origin_input_ids + r.output_ids extra_key = r.extra_key @@ -277,6 +278,7 @@ def _sort_by_longest_prefix( waiting_queue: List[Req], temporary_deprioritized: Set[int] ) -> None: """Sorts the waiting queue based on the longest prefix match.""" + # C10: chunked-resume no longer in waiting_queue (post-C4); revert to main-upstream sort. waiting_queue.sort( key=lambda r: ( -len(r.prefix_indices) @@ -290,6 +292,7 @@ def _sort_by_dfs_weight( waiting_queue: List[Req], tree_cache: BasePrefixCache ) -> None: """Sorts the waiting queue based on a depth-first search weighting.""" + # C10: chunked-resume no longer in waiting_queue (post-C4); revert to main-upstream sort. 
last_node_to_reqs = defaultdict(list) for req in waiting_queue: last_node_to_reqs[req.last_node].append(req) @@ -441,7 +444,6 @@ def __init__( self.req_states = None self.can_run_list = [] self.preempt_list = [] - self.new_chunked_req = None self.log_hit_tokens = 0 # TODO(lsyin): report the real input tokens excluding page alignment self.log_input_tokens = 0 @@ -663,41 +665,6 @@ def add_dllm_staging_req(self, req: Req): else AddReqResult.CONTINUE ) - def add_chunked_req(self, req: Req): - if self.dllm_config is not None: - _rem_tokens = self._get_dllm_remain_tokens() - else: - _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens)) - if self.is_hybrid_swa: - # alloc_extend needs extend_num_tokens + page_size per request, - # so reserve one page here to avoid OOM - _rem_tokens = min( - _rem_tokens, int(self.rem_swa_tokens) - self.page_size - ) - # The chunked_req must be added to the list; otherwise, it will cause a memory leak. - # Therefore, in certain cases where _rem_tokens <= 0, it should be replaced with rem_chunk_tokens. 
- if _rem_tokens <= 0: - if self.is_hybrid_swa: - return req - _rem_tokens = self.rem_chunk_tokens - - truncated = req.extend_input_len > _rem_tokens - req.set_extend_input_len(min(req.extend_input_len, _rem_tokens)) - req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] - self.can_run_list.append(req) - self._update_prefill_budget( - 0, - req.extend_input_len, - ( - min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS) - if not truncated - else 0 - ), - ) - - # Return if chunked prefill not finished - return req if truncated else None - @contextmanager def _lock_node(self, last_node: TreeNode): dec_lock_params = None @@ -784,6 +751,7 @@ def add_req_state(r, insert_sort=False): return AddReqResult.OTHER self._add_dllm_req(req, 0) + truncated = False elif ( self.rem_chunk_tokens is None # chunked prefill is disabled or req.extend_input_len <= self.rem_chunk_tokens # it is the last chunk @@ -795,6 +763,7 @@ def add_req_state(r, insert_sort=False): req.extend_input_len, min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS), ) + truncated = False else: if self.rem_chunk_tokens <= 0: return AddReqResult.OTHER @@ -805,14 +774,23 @@ def add_req_state(r, insert_sort=False): req.set_extend_input_len(trunc_len) req.fill_ids = req.fill_ids[:trunc_len] self.can_run_list.append(req) - self.new_chunked_req = req self._update_prefill_budget(0, trunc_len, 0) + truncated = True + + if not req.is_dllm(): + req.has_pending_chunk = truncated return self.budget_state() - def add_one_req( - self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int] - ): + def add_one_req(self, req: Req, truncation_align_size: Optional[int]): + # Reuse path: this req's previous chunk left lock_ref held, prefix + # already in tree, and init_load_back already consumed host KV. We + # must skip fresh-req setup. 
Gate on `has_pending_chunk` (the + # persistent chunked-resume flag) — `kv_committed_len > 0` alone is + # wider (streaming-session turn N>1 also has it without being + # chunked-resume) and would skip _req_inc_lock_ref incorrectly. + is_resume = req.has_pending_chunk and not req.is_dllm() + if (self.prefill_delayer_single_pass is not None) and ( not self.prefill_delayer_single_pass.negotiate_should_allow_prefill( local_prefillable=True, @@ -881,6 +859,10 @@ def add_one_req( if swa_needed >= self.rem_swa_tokens: return AddReqResult.NO_TOKEN + # Fresh-only init_load_back. For reuse, host_hit_length was set + # on first admission and reset by prepare_for_extend after the + # cache-breakdown metric was computed, so the predicate naturally + # short-circuits here for reuse. if req.host_hit_length > 0: new_indices, req.last_node = self.tree_cache.init_load_back( InitLoadBackParams( @@ -906,6 +888,10 @@ def add_one_req( # - if the can_run_list is empty, always accept the first prefill request return AddReqResult.OTHER + # Budget prefix_len: 0 for reuse (already counted by previous + # admission's stash into tree); actual prefix_len for fresh. + budget_prefix = 0 if is_resume else prefix_len + if self.dllm_config is not None: if self.rem_dllm_tokens <= 0: return AddReqResult.OTHER @@ -916,20 +902,24 @@ def add_one_req( self._add_dllm_req(req, prefix_len) self._req_inc_lock_ref(req) + truncated = False elif self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens: - # Non-chunked prefill + # Non-chunked prefill (or last chunk of a chunked-resume req). self.can_run_list.append(req) - self._req_inc_lock_ref(req) + if not is_resume: + self._req_inc_lock_ref(req) self._update_prefill_budget( - prefix_len, + budget_prefix, input_tokens, min( req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS, ), ) + truncated = False else: + # Chunked prefill: this admission doesn't complete the prefill. 
# Make sure at least one page is available trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size @@ -954,15 +944,20 @@ def add_one_req( if trunc_len <= 0: return AddReqResult.OTHER - # Chunked prefill req.set_extend_input_len(trunc_len) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] self.can_run_list.append(req) - self.new_chunked_req = req - self._req_inc_lock_ref(req) - self._update_prefill_budget(prefix_len, trunc_len, 0) + if not is_resume: + self._req_inc_lock_ref(req) + self._update_prefill_budget(budget_prefix, trunc_len, 0) + truncated = True + + # has_pending_chunk: persistent flag carrying chunked-resume state + # across iters. DLLM uses its own staging_queue + pending_middle_outputs counter. + if not req.is_dllm(): + req.has_pending_chunk = truncated return self.budget_state() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 107565b2f57b..f8b65e5ca070 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -25,7 +25,7 @@ from contextlib import contextmanager, nullcontext from functools import partial from http import HTTPStatus -from typing import Any, Deque, Dict, List, Optional, Tuple, Union +from typing import Any, Deque, Dict, Iterable, List, Optional, Tuple, Union from sglang.srt.utils.common import suppress_noisy_warnings @@ -609,6 +609,7 @@ def __init__( max_total_num_tokens=self.max_total_num_tokens, get_last_batch=lambda: self.last_batch, get_running_batch=lambda: self.running_batch, + get_active_reqs=lambda: self.active_reqs, ) self.invariant_checker = SchedulerInvariantChecker( @@ -626,6 +627,7 @@ def __init__( pool_stats_observer=self.pool_stats_observer, get_last_batch=lambda: self.last_batch, get_running_batch=lambda: self.running_batch, + get_active_reqs=lambda: self.active_reqs, ) self.kv_events_publisher = SchedulerKvEventsPublisher( @@ -655,7 +657,6 @@ def __init__( get_running_batch=lambda: 
self.running_batch, get_waiting_queue=lambda: self.waiting_queue, get_stats=lambda: self.metrics_reporter.stats, - get_chunked_req=lambda: self.chunked_req, get_disagg_prefill_bootstrap_queue=lambda: self.disagg_prefill_bootstrap_queue, get_disagg_prefill_inflight_queue=lambda: self.disagg_prefill_inflight_queue, get_disagg_decode_prealloc_queue=lambda: self.disagg_decode_prealloc_queue, @@ -697,6 +698,7 @@ def __init__( ), output_streamer=self.output_streamer, abort_request=self.abort_request, + deactivate_req=self._deactivate, ) self.is_initializing = False @@ -995,6 +997,30 @@ def init_model_worker(self): def init_running_status(self): self.waiting_queue: List[Req] = [] + # `active_reqs`: sync-mode reqs the scheduler currently owns the + # lifecycle of (admitted, not finished, not retracted, not aborted- + # released). by-rid indexed. + # + # Definition (Plan §7-Q7): admitted via `_get_new_batch_prefill_raw` + # and not yet released through finish/retract/abort. Includes normal + # decode reqs AND mid-prefill chunked-resume reqs AND PP cross-mb + # in-flight reqs (the last two: NOT in running_batch.reqs but still + # holding row + KV + lock_ref). + # + # Invariants: + # * `waiting_queue ∩ active_reqs == ∅` (sync mode; disagg modes use + # their own ownership managers, see Q1=(c)). + # * `set(running_batch.reqs) ⊆ active_reqs` (in-batch always active). + # * `set(chunked_reqs()) ⊆ active_reqs` (by definition). + # * `len(list(chunked_reqs())) <= 1` (Q5 single-flight; asserted at + # inline chunked admission entry). + # * `active_reqs` keys are in 1:1 correspondence with allocated + # `req_to_token_pool` rows (sync mode). + # + # Maintained at: `_activate` / `_deactivate` (only entry points). + # See agent-drafts/2026-05-25-waiting-queue-refactor-plan.md and + # 2026-05-25-scheduler-lifecycle-audit.md. 
+ self.active_reqs: Dict[str, Req] = {} # The running decoding batch for continuous batching self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False) # The current forward batch @@ -1012,6 +1038,77 @@ def init_running_status(self): self.forward_sleep_time = None self._engine_paused = False + def _activate(self, req: Req) -> None: + """Mark req as entering active lifecycle. + + Gated: only sync mode + disagg PREFILL + non-DLLM reqs enter + active_reqs. Disagg DECODE has its own prealloc/transfer queue + ownership; DLLM has its own staging_queue. See plan §2 Scope and + C10 fix plan §2 (disagg PREFILL bug). + """ + if self.disaggregation_mode == DisaggregationMode.DECODE: + return + if req.is_dllm(): + return + assert req.rid not in self.active_reqs, f"already active: {req.rid}" + self.active_reqs[req.rid] = req + + def _deactivate(self, req: Req) -> None: + """Mark req as leaving active lifecycle (finish / abort / retract). + + Important: this function ONLY pops from active_reqs dict. + - Does not clear req.req_pool_idx: batch_result_processor.py:774-787 PP + cross-mb idempotency guard relies on it as an "already released" + sentinel. + - Does not clear req.has_pending_chunk / req.pending_middle_outputs: + owned by the semantic finish/abort/retract sites. + - Does not call release_kv_cache: that is the responsibility of + release_req / abort / finish paths. + This function only answers "scheduler no longer owns this req's + lifecycle". + """ + self.active_reqs.pop(req.rid, None) + + def _assert_invariants(self) -> None: + """Debug-only invariant checks for active_reqs ownership tracking. + + Gated by DEBUG_INVARIANTS=1 to avoid slowing down normal runs. Skipped + in disagg modes (Q1=(c): disagg has its own ownership model). 
+ """ + if not os.environ.get("DEBUG_INVARIANTS"): + return + if self.disaggregation_mode != DisaggregationMode.NULL: + return + waiting_rids = {r.rid for r in self.waiting_queue} + active_rids = set(self.active_reqs.keys()) + running_rids = {r.rid for r in self.running_batch.reqs} + + # sync mode: waiting_queue and active_reqs are strictly disjoint + # (C4 removed chunked-resume retention; chunked-resume now lives in + # active_reqs only). + assert not waiting_rids & active_rids, ( + f"waiting_queue and active_reqs must be disjoint (sync mode); " + f"overlap: {waiting_rids & active_rids}" + ) + + assert ( + running_rids <= active_rids + ), f"running not subset of active: {running_rids - active_rids}" + + def chunked_reqs(self) -> Iterable[Req]: + """Active reqs currently in mid-prefill (`has_pending_chunk=True`). + + Derived view over `active_reqs` — no separate storage. Single-flight + invariant (Q5): `len(list(chunked_reqs())) <= 1` at any iter + boundary; asserted at the entry of the inline chunked admission + block in `_get_new_batch_prefill_raw`. + + Iteration semantics: returns a fresh generator each call; consume + once or wrap in `list(...)`. Callers that mutate `active_reqs` + during iteration must `list(...)` first. + """ + return (r for r in self.active_reqs.values() if r.has_pending_chunk) + def init_chunked_prefill(self): self.chunked_prefill_size = self.server_args.chunked_prefill_size uses_transformers_backend = ( @@ -1030,16 +1127,13 @@ def init_chunked_prefill(self): self.chunked_prefill_size = None elif self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0: self.chunked_prefill_size = None - self.chunked_req = None - # Tracks whether the current self.chunked_req was actually scheduled - # into last iteration's batch (i.e., in can_run_list -> got a fresh - # req_pool_idx from prepare_for_extend). 
Used to gate the - # stash_chunked_request call at the top of get_next_batch_to_run: - # if add_chunked_req early-returned under hybrid-SWA pressure, - # the req_pool_idx was already freed and fill_ids was reset by - # init_next_round_input, so running stash would double-free and - # corrupt prefix_indices. - self._chunked_req_scheduled_last_iter = False + # Chunked-resume tracking: per-Req (has_pending_chunk + + # pending_middle_outputs). After the C1-C7 refactor, chunked-resume + # reqs live exclusively in `active_reqs` (not waiting_queue); Stage A + # iterates `chunked_reqs()` derived from active_reqs. The inline + # chunked admission block at the top of `_get_new_batch_prefill_raw` + # re-admits them each iter. See agent-drafts/ + # 2026-05-25-waiting-queue-refactor-plan.md. self.is_mixed_chunk = ( self.chunked_prefill_size is not None and self.server_args.enable_mixed_chunk @@ -2155,6 +2249,7 @@ def _abort_on_waiting_timeout(self): deleted_reqs = set() deadline = time.perf_counter() - timeout_s + # audit AB7: chunked-resume no longer in waiting_queue (C4), bypass removed. 
for req in self.waiting_queue: entry_time = req.time_stats.wait_queue_entry_time if 0 < entry_time < deadline: @@ -2257,9 +2352,6 @@ def handle_batch_embedding_request( for tokenized_req in recv_req: self.handle_embedding_request(tokenized_req) - def stash_chunked_request(self, req: Req): - maybe_cache_unfinished_req(req, self.tree_cache, chunked=True) - def _build_hisparse_decode_batch(self, reqs): """Build a ScheduleBatch for hisparse requests transitioning from staging to decode.""" device = self.device @@ -2296,7 +2388,43 @@ def _build_hisparse_decode_batch(self, reqs): # todo hisparse, maybe other info to contain for the new batch return batch + def _in_flight_other_mb_rids(self) -> set: + """rids of reqs whose chunked-prefill forward is launched in another + PP microbatch but whose result has not yet been processed by the + output processor — AND for which a follow-up decode would actually + propagate corruption (max_new_tokens > 1). + + In PP+chunked-prefill, mb_a's LAST chunk admit clears has_pending_chunk + on the req while mb_a's chunk forward result is still in flight. If + mb_b's filter_batch merges this req into running_batch, mb_b's decode + forward runs on stale state — input falls back to origin[-1] and + writes WRONG K,V at row position N. The wrong K,V at N persists in + the KV pool and corrupts every subsequent decode position. + + For req.sampling_params.max_new_tokens == 1, the wrong decode result + is filtered by `req.finished()` (line ~240) before being appended, + and the wrong K,V at N is released with the rest of the row when + the req finishes — no observable effect. Excluding such reqs would + delay them by 1 mb step for no correctness gain, so we skip them + here and only return rids of reqs that genuinely need protection. 
+ """ + if self.ps.pp_size <= 1 or not hasattr(self, "mbs"): + return set() + rids = set() + for mb in self.mbs: + if mb is None or mb is self.last_batch: + continue + for r in mb.reqs: + # max_new_tokens is normalized to a non-None int in + # _prepare_input_for_image_request / similar paths during + # request admission, but defensively handle missing/zero. + max_new = r.sampling_params.max_new_tokens or 0 + if max_new > 1: + rids.add(r.rid) + return rids + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: + self._assert_invariants() if self.enable_fpm: self._fpm_batch_t0 = time.monotonic() self._abort_on_waiting_timeout() @@ -2304,21 +2432,30 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: if self.dllm_config is not None: self.dllm_manager.filter_finished_reqs() - # Merge the prefill batch into the running batch - chunked_req_to_exclude = set() + # Stage A: stash any in-flight chunked prefill KV into radix tree. + # Per-req loop over chunked_reqs() covers chunked-resume; DLLM staging + # reqs are owned by DllmManager (not in active_reqs), handled + # separately below. + # + # Why this runs at the iter boundary (not at the end of the prior iter): + # admission inside get_new_batch_prefill_raw reads req.prefix_indices to + # decide extend_input_len. Stashing in the middle of admission would let + # a chunked-resume req "match itself" — the tree would expose KV this + # same req just wrote, double-counting it as cached prefix. Keeping + # stash here means admission only ever sees tree state that is stable + # for the duration of the scheduling pass. vLLM / TokenSpeed do not + # need this because their admission reads a single monotone counter + # (num_computed_tokens / FSM state), not a prefix-indices splice. + # audit P1: Stage A — stash chunked-resume KV into radix tree at iter + # boundary. Switch from scanning waiting_queue (H3 hack) to iterating + # the chunked_reqs() view directly. 
+ for req in self.chunked_reqs(): + if not req.is_dllm(): + maybe_cache_unfinished_req(req, self.tree_cache, chunked=True) if self.dllm_config is not None and self.dllm_manager.any_staging_reqs(): - chunked_req_to_exclude.update(self.dllm_manager.staging_queue) for req in self.dllm_manager.staging_queue: - self.stash_chunked_request(req) - - if self.chunked_req is not None: - # Move the chunked request out of the batch so that we can merge - # only finished requests to running_batch. - chunked_req_to_exclude.add(self.chunked_req) - - if self._chunked_req_scheduled_last_iter: - self.stash_chunked_request(self.chunked_req) + maybe_cache_unfinished_req(req, self.tree_cache, chunked=True) # HiSparse has its own prefill-to-decode transition; skip last_batch merge. if self.enable_hisparse: @@ -2338,18 +2475,19 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: and self.last_batch and self.last_batch.forward_mode.is_extend() ): - if self.last_batch.chunked_req is not None: - # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. - # We need to discard it. - chunked_req_to_exclude.add(self.last_batch.chunked_req) - - if self.dllm_config is not None and self.last_batch.reqs: - chunked_req_to_exclude.update(self.last_batch.reqs) - - # Filter batch last_bs = self.last_batch.batch_size() + # Drop chunked-resume reqs before merging last_batch into + # running_batch. running_batch runs decode forward and admitting + # a mid-prefill req there breaks shapes + KV accounting. The + # dropped reqs persist in self.active_reqs and re-enter via the + # inline chunked admission in _get_new_batch_prefill_raw. + # + # PP cross-mb: also drop reqs whose LAST chunk forward is still + # in flight in another mb (when more decodes will follow — i.e., + # max_new_tokens > 1). See _in_flight_other_mb_rids for rationale. 
self.last_batch.filter_batch( - chunked_req_to_exclude=list(chunked_req_to_exclude) + exclude_chunked_req=True, + exclude_in_flight_other_mb=self._in_flight_other_mb_rids(), ) if self.last_batch.batch_size() < last_bs: self.running_batch.batch_is_full = False @@ -2368,7 +2506,16 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Runs outside the last_batch block so stale requests are cleaned # even when no new batches arrive (e.g. traffic stops). if self.running_batch.is_prefill_only: - self.running_batch.filter_batch() + # Defensive exclude_chunked_req: the merge step above already + # drops chunked-resume reqs from last_batch, so running_batch + # shouldn't normally hold one. Keep the flag set so any leak in + # that invariant doesn't survive here; the dropped req remains + # in active_reqs (post-C4) and is re-admitted next iter via the + # inline chunked admission block in _get_new_batch_prefill_raw. + self.running_batch.filter_batch( + exclude_chunked_req=True, + exclude_in_flight_other_mb=self._in_flight_other_mb_rids(), + ) if self.running_batch.is_empty(): self.running_batch.batch_is_full = False @@ -2417,6 +2564,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: if self.enable_fpm: ret.fpm_start_time = self._fpm_batch_t0 + self._assert_invariants() return ret def get_num_allocatable_reqs(self, running_bs): @@ -2447,6 +2595,14 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: def _get_new_batch_prefill_raw( self, prefill_delayer_single_pass: Optional[PrefillDelayerSinglePassExecutor] ) -> Optional[ScheduleBatch]: + # Chunked-resume admission: handled by the small block at the top of this + # method, which feeds the single chunked-resume req (if any) through + # `adder.add_one_req`. PrefillAdder.add_one_req detects chunked-resume via + # the `is_resume` flag (has_pending_chunk and not is_dllm) and handles all + # budget bookkeeping in one place — no special add_chunked_req method + # resurrected. 
The main waiting_queue loop below admits ONLY truly-waiting + # reqs. See agent-drafts/2026-05-25-waiting-queue-refactor-plan.md §C3 (and + # C9 follow-up). # Check if the grammar is ready in the grammar queue if self.grammar_manager.has_waiting_grammars(): ready_grammar_requests = self.grammar_manager.get_ready_grammar_requests() @@ -2460,21 +2616,30 @@ def _get_new_batch_prefill_raw( # Reset batch_is_full to try preemption with a prefill adder. self.running_batch.batch_is_full = False + # audit H4 + Q5: chunked-resume now lives in active_reqs (not + # waiting_queue, post-C4). Compute the single-flight view once here + # and reuse below for early-exit relaxation, dynamic chunking, and + # the inline chunked admission entry. + chunked_in_active = list(self.chunked_reqs()) + assert len(chunked_in_active) <= 1, ( + f"single-flight violated: {len(chunked_in_active)} chunked reqs " + f"in active ({[r.rid for r in chunked_in_active]})" + ) + if ( self.running_batch.batch_is_full or len(self.waiting_queue) == 0 - ) and self.chunked_req is None: + ) and not chunked_in_active: return None running_bs = len(self.running_batch.reqs) - # Ignore the check if self.chunked_req is not None. - # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0, - # as the space for the chunked requests has just been released. - # In PP case, chunked requests (or dllm requests) can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict. - # Instead, we should always allow chunked requests to be added, otherwise, there will be a memory leak. + # Ignore the check if there is a chunked-resume in flight. + # In the non-PP case the row was just released so the count is fine; + # in PP case, chunked reqs span microbatches so the per-mb max_running + # check should not block them. 
if ( self.get_num_allocatable_reqs(running_bs) <= 0 - and self.chunked_req is None + and not chunked_in_active and not self.enable_priority_preemption ): self.running_batch.batch_is_full = True @@ -2491,8 +2656,12 @@ def _get_new_batch_prefill_raw( # Determine chunked_prefill_size for this batch chunked_prefill_size = self.chunked_prefill_size - if self.chunked_req is not None and self.enable_dynamic_chunking: - history_len = len(self.chunked_req.prefix_indices) + if self.enable_dynamic_chunking and chunked_in_active: + # audit H5: chunked-resume lives in active_reqs; reuse the + # single-flight view computed above instead of scanning + # waiting_queue. + chunked_resume = chunked_in_active[0] + history_len = len(chunked_resume.prefix_indices) dynamic_size = self.predict_next_chunk_size(history_len) if dynamic_size is not None: chunked_prefill_size = dynamic_size @@ -2516,14 +2685,23 @@ def _get_new_batch_prefill_raw( waiting_queue_len=len(self.waiting_queue), ) - if self.chunked_req is not None: - self.chunked_req.init_next_round_input() - self.chunked_req = adder.add_chunked_req(self.chunked_req) - self._chunked_req_scheduled_last_iter = ( - self.chunked_req in adder.can_run_list + if chunked_in_active: + chunked_req = chunked_in_active[0] + # No tree_cache: chunked-resume MUST NOT re-match prefix (H7). + # Its row + KV + lock_ref are already held from prior admission. + chunked_req.init_next_round_input() + # Use the standard adder.add_one_req — its `is_resume` branch + # (schedule_policy.py:811) handles chunked-resume correctly: + # - budget_prefix=0 (don't double-count prefix) + # - skip _req_inc_lock_ref (already held) + # - update has_pending_chunk = truncated + # By running BEFORE the main waiting_queue loop, the chunked req + # also skips LoRA drainer / hicache prefetch checks that the + # main loop applies to fresh reqs. 
+ adder.add_one_req( + chunked_req, + truncation_align_size=self.truncation_align_size, ) - else: - self._chunked_req_scheduled_last_iter = False if self.enable_lora: running_loras = {req.lora_id for req in self.running_batch.reqs} @@ -2536,6 +2714,8 @@ def _get_new_batch_prefill_raw( # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: + # audit H6: chunked-resume no longer flows through main loop; + # drainer check applies uniformly. if self.enable_lora and not self._can_schedule_lora_req(req, running_loras): continue @@ -2565,10 +2745,11 @@ def _get_new_batch_prefill_raw( req.rid ) + # audit H7: chunked-resume handled in inline admission above; + # main loop unconditional. req.init_next_round_input(self.tree_cache) res = adder.add_one_req( req, - has_chunked_req=(self.chunked_req is not None), truncation_align_size=self.truncation_align_size, ) @@ -2588,9 +2769,13 @@ def _get_new_batch_prefill_raw( # Only free if the slot was freshly allocated in this batch (not # pre-existing from a session). Session-held slots have their own # lifecycle and freeing them here causes double-free. + # Chunked-resume reqs inherit mamba_pool_idx from their first + # admission; freeing it on a transient NO_TOKEN this iter would + # discard a live mamba state still needed by subsequent chunks. added = len(adder.can_run_list) > 0 and req is adder.can_run_list[-1] if ( not added + and not req.has_pending_chunk and req.mamba_pool_idx is not None and not getattr(req, "session", None) ): @@ -2605,22 +2790,44 @@ def _get_new_batch_prefill_raw( if len(can_run_list) == 0: return None + # audit A1: mark newly-admitted reqs as active. Post-C3/C4 the main + # admission loop (the for-loop over waiting_queue above) only + # produces brand-new admissions. 
The inline chunked admission block + # also appends to `can_run_list` for chunked-resume re-admit, and + # those reqs are already in active_reqs from a prior iter (the + # inline block does NOT call _activate). Skip them here so the + # strict `_activate` assert (post-C7) catches accidental + # double-admission for everything else. + for req in can_run_list: + if req.rid in self.active_reqs: + continue + self._activate(req) + + # audit H2: retention removed. chunked-resume reqs are no longer + # anchored in waiting_queue — they live in active_reqs and are + # re-admitted via the inline chunked admission loop (C3). can_run_set = set(can_run_list) self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_set] if adder.preempt_list: for req in adder.preempt_list: + # audit R2: PrefillAdder.preempt_to_schedule already released + # the victim's resources via running_batch.release_req. Drop + # from active_reqs before re-enqueueing as a waiting req. + self._deactivate(req) self._add_request_to_queue(req) - if adder.new_chunked_req is not None: - # Update chunked prefill - assert self.chunked_req is None - self.chunked_req = adder.new_chunked_req - # new_chunked_req is added to can_run_list by add_one_req, - # so it will be scheduled this iter -> stash is needed next iter. - self._chunked_req_scheduled_last_iter = True - - if self.chunked_req is not None: - self.chunked_req.inflight_middle_chunks += 1 + # Bump pending_middle_outputs for every admitted req that's still + # mid-prefill — output processor uses this to know its forward's + # sample is garbage. Counter semantics needed for PP, where multiple + # microbatches may admit the same req. 
+ chunked_in_batch = [r for r in can_run_list if r.has_pending_chunk] + assert ( + len(chunked_in_batch) <= 1 + ), "single-flight invariant: at most one chunked-resume req per batch" + chunk_deduct = 0 + for r in chunked_in_batch: + r.pending_middle_outputs += 1 + chunk_deduct = r.extend_input_len set_time_batch(can_run_list, "set_forward_entry_time") @@ -2633,7 +2840,6 @@ def _get_new_batch_prefill_raw( self.model_config, self.enable_overlap, self.spec_algorithm, - chunked_req=self.chunked_req, ) self.max_prefill_bs = max(self.max_prefill_bs, len(can_run_list)) if self.enable_hierarchical_cache: @@ -2650,11 +2856,7 @@ def _get_new_batch_prefill_raw( self.running_batch.reqs, self.enable_priority_scheduling, num_pending_tokens=self.load_inquirer._get_num_pending_tokens( - chunk_deduct=( - self.chunked_req.extend_input_len - if self.chunked_req is not None - else 0 - ), + chunk_deduct=chunk_deduct ), ) @@ -2667,7 +2869,13 @@ def _get_new_batch_prefill_raw( and new_batch.input_embeds is None ): # TODO (lianmin): support return_logprob + mixed chunked prefill - self.running_batch.filter_batch(v1_spec_info_filtered=True) + # exclude_chunked_req here is defensive — by design running_batch + # holds decode reqs only (the last_batch filter+merge step above + # already drops chunked-resume), and any dropped chunked-resume + # stays in active_reqs and is re-admitted by the inline chunked admission. + self.running_batch.filter_batch( + v1_spec_info_filtered=True, exclude_chunked_req=True + ) if not self.running_batch.is_empty(): self.running_batch.prepare_for_decode() new_batch.mix_with_running(self.running_batch) @@ -2779,7 +2987,13 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: ) logger.warning(msg_prefix + msg_details) + # audit R1: retract_decode released row + KV via release_req for + # both retracted_reqs (re-enqueued as waiting) and reqs_to_abort + # (final OOM eviction). Drop both from active_reqs. 
+ for req in reqs_to_abort: + self._deactivate(req) for req in retracted_reqs: + self._deactivate(req) self._add_request_to_queue(req, is_retracted=True) else: self.new_token_ratio_tracker.decay_step() @@ -3149,7 +3363,6 @@ def is_fully_idle(self, for_health_check=False) -> bool: # Batch running status idle = ( self.running_batch.is_empty() - and self.chunked_req is None and not self.dllm_manager.any_staging_reqs() and (self.last_batch is None or self.last_batch.is_empty()) and (self.cur_batch is None or self.cur_batch.is_empty()) @@ -3294,6 +3507,7 @@ def flush_cache(self, empty_cache: bool = True): self.last_batch = None self.tree_cache.reset() self.req_to_token_pool.clear() + self.active_reqs.clear() # audit: keep parallel to req_to_token_pool reset (C8) self.token_to_kv_pool_allocator.clear() self.grammar_manager.clear() self.metrics_reporter.reset_metrics() @@ -3423,10 +3637,34 @@ def handle_rpc_request(self, recv_req: RpcReqInput): def abort_request(self, recv_req: AbortReq): # todo hisparse, release resources for abort requests in hisparse coordinator + # Post-C4: chunked-resume reqs live in active_reqs only, never in waiting_queue. + if self.cur_batch is self.running_batch or self.cur_batch is None: + batch_reqs = list(self.running_batch.reqs) + else: + batch_reqs = list(self.running_batch.reqs) + list(self.cur_batch.reqs) + # PP: rids from every in-flight microbatch must also be treated as + # 'in batch'. Each mb's forward was launched against the req's + # req_pool_idx + KV slots; the output processor on a different mb + # iteration consumes the result later. 
Without this, a chunked-resume + # req with pending_middle_outputs > 0 whose forward is in flight in another mb would + # fall into the not-in-any-batch abort path, release_kv_cache would free + # the row + KV underneath the still-launched forward, and the delayed + # output processor would crash on a None req_pool_idx (or, with + # pending_middle_outputs cleared to 0, mistake the middle-chunk + # result for a full output and append garbage tokens). + if self.ps.pp_size > 1 and hasattr(self, "mbs"): + for mb_list in (self.mbs, self.last_mbs, self.running_mbs): + for mb in mb_list: + if mb is not None and not mb.is_empty(): + batch_reqs.extend(mb.reqs) + batch_rids = {r.rid for r in batch_reqs} + # Delete requests in the waiting queue to_del = [] for i, req in enumerate(self.waiting_queue): - if recv_req.abort_all or req.rid.startswith(recv_req.rid): + if (recv_req.abort_all or req.rid.startswith(recv_req.rid)) and ( + req.rid not in batch_rids + ): to_del.append(i) # Sort in reverse order to avoid index issues when deleting @@ -3448,12 +3686,19 @@ def abort_request(self, recv_req: AbortReq): req, self.req_to_metadata_buffer_idx_allocator ) - # For mamba radix cache + # audit AB4 simplified post-C4: only mamba radix cache reqs can be + # in waiting_queue with mamba_pool_idx held. Chunked-resume reqs + # are NOT in waiting_queue anymore (live in active_reqs); their + # abort-time release happens in the active_reqs loop below. if ( req.mamba_pool_idx is not None and self.disaggregation_mode != DisaggregationMode.DECODE ): release_kv_cache(req, self.tree_cache, is_insert=False) + # audit D6 (mamba branch): drop from active set if present. + # (mamba-radix path may or may not put req in active_reqs; + # _deactivate is idempotent.) + self._deactivate(req) logger.debug(f"Abort queued request. 
{req.rid=}") # Delete the requests in the grammar queue @@ -3505,21 +3750,61 @@ def abort_request(self, recv_req: AbortReq): remaining_retracted.append(decode_req) self.disagg_decode_prealloc_queue.retracted_queue = remaining_retracted - # Delete requests in the running batch - if self.cur_batch is self.running_batch or self.cur_batch is None: - reqs = self.running_batch.reqs - else: - reqs = self.running_batch.reqs + self.cur_batch.reqs + # audit finding 2 (Plan §C6 Edit 3): iterate active_reqs instead of + # batch_reqs so that stashed chunked-resume reqs (in active_reqs but + # NOT in any current batch) get their resources released immediately. + # batch_rids was built above and includes cur_batch + running_batch + + # PP mbs[*]; "in-batch" reqs go through to_finish, "stashed-chunked" + # reqs need explicit release because no batch result path will pick + # them up. + for rid in list(self.active_reqs.keys()): + req = self.active_reqs[rid] + if req.finished(): + continue + if not (recv_req.abort_all or rid.startswith(recv_req.rid)): + continue - for req in reqs: - if not req.finished() and ( - recv_req.abort_all or req.rid.startswith(recv_req.rid) - ): - # Abort method 3: set `to_finish` - # The request will still run one decode forward pass. - # Then we reuse all existing code to clean up the KV cache allocation. + if rid in batch_rids: + # In some batch: standard to_finish path; release_kv_cache + + # _deactivate happen in process_batch_result_*. logger.debug(f"Abort running request. {req.rid=}") req.to_finish = FINISH_ABORT() + else: + # Active but not in any batch — the only legitimate case is + # a stashed chunked-resume mid-prefill (audit finding 2). + # Release immediately, else row+KV+lock_ref leak. 
+ assert req.has_pending_chunk and req.req_pool_idx is not None, ( + f"unexpected active-but-not-in-batch req: {rid} " + f"has_pending_chunk={req.has_pending_chunk} " + f"req_pool_idx={req.req_pool_idx}" + ) + if self.disaggregation_mode != DisaggregationMode.DECODE: + # C11: disagg PREFILL stashed-chunked req has already been + # sending KV chunks to the peer decode node. Signal abort so + # the peer doesn't wait forever for the remaining chunks. + # Mirrors pause_generation(retract) PREFILL handling + # (scheduler.py pause section). + if ( + self.disaggregation_mode == DisaggregationMode.PREFILL + and req.disagg_kv_sender is not None + ): + if hasattr(req.disagg_kv_sender, "abort"): + req.disagg_kv_sender.abort() + req.disagg_kv_sender = None + + release_kv_cache(req, self.tree_cache, is_insert=False) + + # C11: PREFILL mode also needs to release the metadata buffer + # slot. Mirrors abort_request waiting-segment PREFILL handling. + if self.disaggregation_mode == DisaggregationMode.PREFILL: + release_req_to_metadata_buffer( + req, self.req_to_metadata_buffer_idx_allocator + ) + + req.has_pending_chunk = False + req.pending_middle_outputs = 0 + self._deactivate(req) + logger.debug(f"Abort stashed chunked-resume request. {req.rid=}") def _pause_engine(self) -> Tuple[List[Req], int]: raise NotImplementedError() @@ -3529,11 +3814,11 @@ def pause_generation(self, recv_req: PauseGenerationReqInput): if recv_req.mode == "in_place": # In-place pause: just set the flag and return immediately. - # All scheduler state (running_batch, last_batch, chunked_req, + # All scheduler state (running_batch, last_batch, waiting_queue, # result_queue) is left untouched. On resume, the normal event # loop (get_next_batch_to_run) handles last_batch merge, - # chunked_req cleanup, and overlap result processing through - # the standard code paths. This avoids duplicating batch + # chunked-resume re-admission, and overlap result processing + # through the standard code paths. 
This avoids duplicating batch # manipulation logic and the accounting bugs that come with it. return @@ -3543,10 +3828,10 @@ def pause_generation(self, recv_req: PauseGenerationReqInput): self.process_batch_result(tmp_batch, tmp_result) if self.last_batch and self.last_batch.forward_mode.is_extend(): - chunked_req_to_exclude = set() - self.last_batch.filter_batch( - chunked_req_to_exclude=list(chunked_req_to_exclude) - ) + # Same invariant as the non-disagg merge path: drop chunked-resume + # reqs before potentially folding last_batch into running_batch. + # They persist in active_reqs and re-enter via the inline chunked admission. + self.last_batch.filter_batch(exclude_chunked_req=True) # Skip merge for disagg prefill: completed prefill requests are # already in disagg_prefill_inflight_queue. Merging them into # running_batch leaks them, since the prefill event loop never @@ -3563,15 +3848,56 @@ def pause_generation(self, recv_req: PauseGenerationReqInput): self.last_batch = None self.cur_batch = None - if recv_req.mode == "retract" and not self.running_batch.is_empty(): - self.running_batch.filter_batch(v1_spec_info_filtered=True) - if len(self.running_batch.reqs) != 0: - retracted_reqs = self.running_batch.retract_all(self.server_args) - for req in retracted_reqs: - self._add_request_to_queue(req) + if recv_req.mode == "retract": + if not self.running_batch.is_empty(): + self.running_batch.filter_batch(v1_spec_info_filtered=True) + if len(self.running_batch.reqs) != 0: + retracted_reqs = self.running_batch.retract_all(self.server_args) + # audit R3: retract_all released resources via release_req + # for every running req; drop from active_reqs before + # re-enqueueing as waiting. + for req in retracted_reqs: + self._deactivate(req) + self._add_request_to_queue(req) + + self.running_batch.batch_is_full = False - self.running_batch.batch_is_full = False - self.chunked_req = None + # Chunked-resume reqs still hold their row + KV + radix lock_ref + # from prior admissions. 
Without explicit release, pause(retract)'s + # 'flush_cache can succeed' contract (see PauseGenerationReqInput + # docstring) is violated. Release in-place and reset their chunked + # state so continue_generation re-prefills them from + # origin_input_ids. + # + # audit C7: chunked-resume lives in active_reqs (post-C4), + # iterate chunked_reqs() directly. list(...) because we mutate + # active_reqs via _deactivate inside the loop. + for req in list(self.chunked_reqs()): + if req.req_pool_idx is not None: + # Disagg-prefill: signal the decode side that the send was + # retracted and drop our sender ref so re-prefill rebuilds + # the bootstrap state. start_send_idx / tmp_end_idx are + # reset by reset_for_retract. + if ( + self.disaggregation_mode == DisaggregationMode.PREFILL + and req.disagg_kv_sender is not None + ): + if hasattr(req.disagg_kv_sender, "abort"): + req.disagg_kv_sender.abort() + req.disagg_kv_sender = None + release_kv_cache(req, self.tree_cache, is_insert=False) + req.reset_for_retract() + # audit D7: chunked-resume req released via reset_for_retract + # no longer holds row/KV, so it leaves the active set. + self._deactivate(req) + # TODO(post-refactor follow-up): plan §10 flag — after + # reset_for_retract, this req is NOT re-enqueued to + # waiting_queue. Either the design relies on the original + # reference staying in waiting_queue (but C4 removed + # retention!) or this is a pre-existing latent bug from + # before the refactor. Investigate separately. See + # agent-drafts/2026-05-25-waiting-queue-refactor-plan.md + # §10. 
def continue_generation(self, recv_req: ContinueGenerationReqInput): if recv_req.torch_empty_cache: diff --git a/python/sglang/srt/managers/scheduler_components/batch_result_processor.py b/python/sglang/srt/managers/scheduler_components/batch_result_processor.py index 5e5e57cacf56..06f8aa29826e 100644 --- a/python/sglang/srt/managers/scheduler_components/batch_result_processor.py +++ b/python/sglang/srt/managers/scheduler_components/batch_result_processor.py @@ -78,6 +78,7 @@ class SchedulerBatchResultProcessor: logprob_result_processor: "SchedulerLogprobResultProcessor" output_streamer: "SchedulerOutputStreamer" abort_request: Callable + deactivate_req: Callable def process_batch_result_prebuilt(self, batch: ScheduleBatch): assert self.disaggregation_mode == DisaggregationMode.DECODE @@ -218,7 +219,7 @@ def process_batch_result_prefill( # decode req in mixed batch or retracted req continue - if req.inflight_middle_chunks <= 0: + if req.pending_middle_outputs <= 0: req.time_stats.set_prefill_finished_time() # req output_ids are set here @@ -231,6 +232,8 @@ def process_batch_result_prefill( self._maybe_collect_routed_experts(req) self._maybe_collect_indexer_topk(req) release_kv_cache(req, self.tree_cache) + # audit D1: sync prefill finish + self.deactivate_req(req) req.time_stats.set_completion_time() elif not batch.decoding_reqs or req not in batch.decoding_reqs: maybe_cache_unfinished_req(req, self.tree_cache) @@ -267,7 +270,7 @@ def process_batch_result_prefill( else: # being chunked reqs' prefill is not finished - req.inflight_middle_chunks -= 1 + req.pending_middle_outputs -= 1 # There is only at most one request being currently chunked. # Because this request does not finish prefill, # we don't want to stream the request currently being chunked. 
@@ -307,7 +310,7 @@ def process_batch_result_prefill( req.embedding = embeddings[i] if req.return_pooled_hidden_states and phs is not None: req.pooled_hidden_state = phs[i] - if req.inflight_middle_chunks <= 0: + if req.pending_middle_outputs <= 0: req.time_stats.set_prefill_finished_time() # Dummy output token for embedding models req.output_ids.append(0) @@ -315,12 +318,14 @@ def process_batch_result_prefill( if req.finished(): release_kv_cache(req, self.tree_cache) + # audit D2: embedding/reward prefill finish + self.deactivate_req(req) req.time_stats.set_completion_time() else: maybe_cache_unfinished_req(req, self.tree_cache) else: # being chunked reqs' prefill is not finished - req.inflight_middle_chunks -= 1 + req.pending_middle_outputs -= 1 req.time_stats.set_last_chunked_prefill_finish_time() self.output_streamer.stream_output( @@ -772,6 +777,20 @@ def _handle_finished_req( self.decode_offload_manager.offload_kv_cache(req) if req.finished(): + # Idempotency guard for PP cross-microbatch races: in PP+chunked + # prefill the same Req object can sit in multiple in-flight + # mbs[*] batches when chunks of one req are pipelined across + # microbatch slots. The slot that processes the last chunk's + # result finalizes the req (release_kv_cache nulls req_pool_idx), + # then a sibling slot's pending result hits the same req again + # here and would trip the assert in release_kv_cache. Treat + # `req_pool_idx is None at finalize` as "already released" and + # skip the redundant cleanup; the first call already collected + # multimodal/experts/indexer/time-stats state. 
+ if req.req_pool_idx is None and not self.tree_cache.supports_mamba(): + self._maybe_collect_customized_info(i, req, logits_output) + return + # delete feature to save memory if req.multimodal_inputs is not None and req.session is None: req.multimodal_inputs.release_features() @@ -786,6 +805,10 @@ def _handle_finished_req( if self.server_args.enable_hisparse: self.hisparse_coordinator.request_finished(req) release_kv_cache(req, self.tree_cache) + # audit D3: sync decode finish (non-offload path). The DECODE + # offload branch (D4) does not call _deactivate — disagg DECODE + # is not in active_reqs (Q1=(c)). + self.deactivate_req(req) req.time_stats.set_completion_time() diff --git a/python/sglang/srt/managers/scheduler_components/invariant_checker.py b/python/sglang/srt/managers/scheduler_components/invariant_checker.py index 237a5a60675e..8a236aba68d1 100644 --- a/python/sglang/srt/managers/scheduler_components/invariant_checker.py +++ b/python/sglang/srt/managers/scheduler_components/invariant_checker.py @@ -50,6 +50,7 @@ class SchedulerInvariantChecker: pool_stats_observer: SchedulerPoolStatsObserver get_last_batch: Callable get_running_batch: Callable + get_active_reqs: Callable count_req_pool_leak_warnings: int = 0 count_memory_leak_warnings: int = 0 @@ -156,17 +157,32 @@ def _get_total_uncached_sizes( """ # After decode: running_batch IS last_batch (same object), count once. # After prefill: they differ, both hold uncached tokens. - batches = [self.get_last_batch()] + req_groups = [list(self.get_last_batch().reqs)] if ( self.get_running_batch() not in (None, self.get_last_batch()) and not self.get_running_batch().is_empty() ): - batches.append(self.get_running_batch()) + req_groups.append(list(self.get_running_batch().reqs)) + # Chunked-resume reqs in active_reqs carry uncached tail + # (kv_committed_len - cache_protected_len, < page_size) that + # filter_batch just removed from last_batch but haven't been + # re-admitted to running_batch yet. 
The leak invariant must count it. + # C10: chunked-resume now lives in active_reqs (post-C4). + seen_ids = {id(req) for group in req_groups for req in group} + chunked_in_active = [ + req + for req in self.get_active_reqs().values() + if req.has_pending_chunk + and req.req_pool_idx is not None + and id(req) not in seen_ids + ] + if chunked_in_active: + req_groups.append(chunked_in_active) full_uncached = 0 swa_uncached = 0 - for batch in batches: - for req in batch.reqs: + for group in req_groups: + for req in group: assert req.kv_committed_freed == req.kv_overallocated_freed if req.kv_committed_freed or req.req_pool_idx is None: continue diff --git a/python/sglang/srt/managers/scheduler_components/load_inquirer.py b/python/sglang/srt/managers/scheduler_components/load_inquirer.py index 32acaa44ee7b..f5ef23144b80 100644 --- a/python/sglang/srt/managers/scheduler_components/load_inquirer.py +++ b/python/sglang/srt/managers/scheduler_components/load_inquirer.py @@ -44,7 +44,6 @@ class SchedulerLoadInquirer: get_running_batch: Callable get_waiting_queue: Callable get_stats: Callable - get_chunked_req: Callable get_disagg_prefill_bootstrap_queue: Callable get_disagg_prefill_inflight_queue: Callable get_disagg_decode_prealloc_queue: Callable @@ -66,11 +65,14 @@ def _get_num_pending_tokens(self, chunk_deduct: int = 0) -> int: time ``prefix_indices`` is already up-to-date, so the default 0 is correct. """ - num_pending_tokens = sum(req.seqlen for req in self.get_waiting_queue()) - if self.get_chunked_req() is not None: - req = self.get_chunked_req() - num_pending_tokens += req.seqlen - len(req.prefix_indices) - chunk_deduct - return num_pending_tokens + num_pending_tokens = sum( + req.seqlen - len(req.prefix_indices) for req in self.get_waiting_queue() + ) + # NOTE(review): post-C4 the chunked-resume req lives in active_reqs, so + # confirm it is still counted in the sum above. 
chunk_deduct subtracts the + # current chunk's extend that has been planned but not yet reflected + # in prefix_indices. + return num_pending_tokens - chunk_deduct def get_loads(self, req: GetLoadsReqInput = None) -> GetLoadsReqOutput: """ diff --git a/python/sglang/srt/managers/scheduler_components/pool_stats_observer.py b/python/sglang/srt/managers/scheduler_components/pool_stats_observer.py index a6ed752282bc..f01cee22814c 100644 --- a/python/sglang/srt/managers/scheduler_components/pool_stats_observer.py +++ b/python/sglang/srt/managers/scheduler_components/pool_stats_observer.py @@ -153,6 +153,7 @@ class SchedulerPoolStatsObserver: max_total_num_tokens: int get_last_batch: Callable get_running_batch: Callable + get_active_reqs: Callable def streaming_session_count(self) -> int: return sum( @@ -162,7 +163,8 @@ def streaming_session_count(self) -> int: ) def active_pool_idxs(self) -> set: - """Pool idxs currently owned by reqs in last_batch / running_batch. + """Pool idxs currently owned by reqs in last_batch / running_batch or + held by chunked-resume reqs in active_reqs. Used to decide which session slots' KV is owned by batch reqs (and thus counted via uncached_size, not session_held). @@ -174,6 +176,13 @@ def active_pool_idxs(self) -> set: for req in batch.reqs: if req.req_pool_idx is not None: idxs.add(req.req_pool_idx) + # Chunked-resume reqs in active_reqs still own their row across iters + # (filter_batch may have just moved them out of last_batch but they + # haven't yet been re-admitted to running_batch). + # C10: chunked-resume now lives in active_reqs (post-C4). 
+ for req in self.get_active_reqs().values(): + if req.has_pending_chunk and req.req_pool_idx is not None: + idxs.add(req.req_pool_idx) return idxs def session_held_tokens(self) -> int: diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 350a2ec8b89c..bc74f548aa7b 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -85,10 +85,11 @@ def cache_finished_req(self, req: Req, is_insert: bool = True): self.token_to_kv_pool_allocator.free(kv_indices) def cache_unfinished_req(self, req: Req, chunked=False): + # Bound row read by kv_committed_len; see radix_cache.py for rationale. kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(req.fill_ids) + req.req_pool_idx, : req.kv_committed_len ] - # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + # `req.prefix_indices` will be used by add_one_req reuse branch next iter req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True) def evict(self, params: EvictParams) -> EvictResult: diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index be209c6a99f3..ce9434ce1d74 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -602,17 +602,20 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: def cache_unfinished_req(self, req: Req, chunked=False) -> None: """Cache request when it is unfinished.""" + # Bound row read by kv_committed_len; see radix_cache.py for rationale. 
+ assert req.kv_committed_len >= req.cache_protected_len + read_len = req.kv_committed_len def _skip_cache_unfinished_req(req: Req) -> None: kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(req.fill_ids) + req.req_pool_idx, :read_len ] - # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + # `req.prefix_indices` will be used by add_one_req reuse branch next iter req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True) return - token_ids = req.fill_ids + token_ids = req.fill_ids[:read_len] cache_len = ( req.mamba_last_track_seqlen if self.enable_mamba_extra_buffer @@ -622,7 +625,7 @@ def _skip_cache_unfinished_req(req: Req) -> None: return _skip_cache_unfinished_req(req) kv_indices_orig = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :read_len ] # kv_indices is the kv indices to be cached kv_indices = kv_indices_orig[:cache_len] @@ -707,7 +710,7 @@ def _skip_cache_unfinished_req(req: Req) -> None: self.dec_lock_ref(req.last_node) self.inc_lock_ref(new_last_node) - # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + # `req.prefix_indices` will be used by add_one_req reuse branch next iter # NOTE: this is needed for both page_size == 1 and page_size > 1 req.prefix_indices = torch.cat( [new_indices, kv_indices_orig[len(new_indices) :]] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 0d520a5bafab..9c4560941783 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -171,16 +171,12 @@ def alloc(self, reqs: list[Req]) -> Optional[List[int]]: # Indices of reqs that already have a req_pool_idx and will reuse # their existing slot (e.g. chunked prefill continuing across chunks). 
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None] - # NOTE: this check is relaxed temporarily - # https://github.com/sgl-project/sglang/pull/20476 - # if not any(r.is_dllm() for r in reqs): - # assert ( - # sum(1 for i in reusing if reqs[i].inflight_middle_chunks > 0) <= 1 - # ), "only one chunked request may reuse req_pool_idx in a batch" + # The row pool only cares whether the row has committed KV — it does + # not need to know whether the req is chunked. kv_committed_len > 0 + # naturally covers chunked-resume + DLLM staging + any reuse case. assert all( - reqs[i].inflight_middle_chunks > 0 or reqs[i].kv_committed_len > 0 - for i in reusing - ), "reusing request must be chunked or have committed KV" + reqs[i].kv_committed_len > 0 for i in reusing + ), "reusing request must have committed KV" need_size = len(reqs) - len(reusing) if need_size > len(self.free_slots): diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 6e35d1a313f2..5f028e33d086 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -490,10 +490,15 @@ def cache_unfinished_req(self, req: Req, chunked=False): if self.disable: return - token_ids = req.fill_ids - kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) - ] + # Bound the row read by kv_committed_len (the actually-written prefix + # length on the row), not by len(fill_ids). They are equal in the + # common path, but init_next_round_input resets fill_ids to the full + # origin + output length while the row only holds KV up to + # kv_committed_len — reading beyond that yields garbage slot indices. 
+ assert req.kv_committed_len >= req.cache_protected_len + read_len = req.kv_committed_len + token_ids = req.fill_ids[:read_len] + kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :read_len] radix_key = RadixKey( token_ids, req.extra_key, is_bigram=self.is_eagle @@ -539,7 +544,7 @@ def cache_unfinished_req(self, req: Req, chunked=False): self.dec_lock_ref(req.last_node) self.inc_lock_ref(new_last_node) - # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + # `req.prefix_indices` will be used by add_one_req reuse branch next iter # - page_size != 1: there is a partial page at the end, keep the full kv_indices # - eagle case: bigram keys will only cache len - 1 kv indices if len(new_indices) < len(kv_indices): diff --git a/python/sglang/srt/mem_cache/radix_cache_cpp.py b/python/sglang/srt/mem_cache/radix_cache_cpp.py index 97bf02835881..ab2ff2152066 100644 --- a/python/sglang/srt/mem_cache/radix_cache_cpp.py +++ b/python/sglang/srt/mem_cache/radix_cache_cpp.py @@ -210,8 +210,10 @@ def cache_finished_req(self, req: Req, is_insert: bool = True): def cache_unfinished_req(self, req: Req, chunked=False): """Cache request when it is unfinished.""" assert req.req_pool_idx is not None - token_ids = req.fill_ids - prefill_len = len(token_ids) # prefill only (maybe chunked) + # Bound row read by kv_committed_len; see radix_cache.py for rationale. 
+ assert req.kv_committed_len >= req.cache_protected_len + prefill_len = req.kv_committed_len + token_ids = req.fill_ids[:prefill_len] kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, :prefill_len ].to(dtype=torch.int64, copy=True) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index dd76ad37df3d..4a6ce71dea05 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -486,19 +486,22 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: def cache_unfinished_req(self, req: Req, chunked=False) -> None: """Cache request when it is unfinished.""" + # Bound the row read by kv_committed_len, not len(fill_ids); see + # radix_cache.py:cache_unfinished_req for the rationale (SWA early- + # return + init_next_round_input leaves fill_ids longer than the row). + assert req.kv_committed_len >= req.cache_protected_len + read_len = req.kv_committed_len if self.disable: kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(req.fill_ids) + req.req_pool_idx, :read_len ] - # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + # `req.prefix_indices` will be used by add_one_req reuse branch next iter req.prefix_indices = kv_indices return - token_ids = req.fill_ids - kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) - ] + token_ids = req.fill_ids[:read_len] + kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :read_len] radix_key = RadixKey( token_ids, req.extra_key, is_bigram=self.is_eagle @@ -542,7 +545,7 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: result = self.inc_lock_ref(new_last_node) swa_uuid_for_lock = result.swa_uuid_for_lock - # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + # `req.prefix_indices` will be used by add_one_req reuse branch next iter if len(new_indices) 
< len(kv_indices): req.prefix_indices = torch.cat( [new_indices, kv_indices[len(new_indices) :]] diff --git a/python/sglang/srt/mem_cache/unified_radix_cache.py b/python/sglang/srt/mem_cache/unified_radix_cache.py index fa298b053b61..1a34e4d9f75d 100644 --- a/python/sglang/srt/mem_cache/unified_radix_cache.py +++ b/python/sglang/srt/mem_cache/unified_radix_cache.py @@ -510,17 +510,20 @@ def cache_unfinished_req(self, req: Req, chunked=False, **kwargs) -> None: if self.session.try_cache_unfinished_req(req, chunked=chunked, **kwargs): return - token_ids = req.fill_ids + # Bound row read by kv_committed_len; see radix_cache.py for rationale. + assert req.kv_committed_len >= req.cache_protected_len + read_len = req.kv_committed_len + token_ids = req.fill_ids[:read_len] if self.disable: kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :read_len ] req.prefix_indices = kv_indices return kv_indices_orig = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :read_len ] # components prepare insert data + return effective cache_len diff --git a/python/sglang/srt/multiplex/multiplexing_mixin.py b/python/sglang/srt/multiplex/multiplexing_mixin.py index d5fd32a334f5..dc19c6624a34 100644 --- a/python/sglang/srt/multiplex/multiplexing_mixin.py +++ b/python/sglang/srt/multiplex/multiplexing_mixin.py @@ -208,10 +208,19 @@ def event_loop_pdmux(self: Scheduler): self.process_batch_result( self.split_prefill_batch, prefill_result ) - if self.running_batch and not self.running_batch.is_empty(): - self.running_batch.merge_batch(self.split_prefill_batch) - else: - self.running_batch = self.split_prefill_batch + # Drop chunked-resume reqs before folding split_prefill_batch + # into running_batch. 
running_batch runs decode forward and + # admitting a mid-prefill req there breaks shape + KV + # accounting; the dropped reqs persist in self.waiting_queue + # (retention in get_new_batch_prefill) and re-enter via the + # next iter's Stage A stash + admission cycle. Mirrors the + # standard event_loop path at scheduler.py:2514. + self.split_prefill_batch.filter_batch(exclude_chunked_req=True) + if not self.split_prefill_batch.is_empty(): + if self.running_batch and not self.running_batch.is_empty(): + self.running_batch.merge_batch(self.split_prefill_batch) + else: + self.running_batch = self.split_prefill_batch self.split_prefill_batch = None wait_prefill_kernel_done = False diff --git a/python/sglang/srt/observability/forward_pass_metrics.py b/python/sglang/srt/observability/forward_pass_metrics.py index e271bd6de43d..00891b308172 100644 --- a/python/sglang/srt/observability/forward_pass_metrics.py +++ b/python/sglang/srt/observability/forward_pass_metrics.py @@ -10,7 +10,7 @@ Data flow:: Scheduler process: - SchedulerMetricsMixin._emit_forward_pass_metrics() + SchedulerMetricsReporter._emit_forward_pass_metrics() -> _FpmPublisherThread -> ZMQ PUB (localhost) External consumer: diff --git a/python/sglang/srt/session/streaming_session.py b/python/sglang/srt/session/streaming_session.py index 17602c3b3b6b..d34443f93bb4 100644 --- a/python/sglang/srt/session/streaming_session.py +++ b/python/sglang/srt/session/streaming_session.py @@ -331,8 +331,15 @@ def try_cache_unfinished_req( if not _is_streaming(req): return False if chunked: + # Bound row read by kv_committed_len, NOT len(fill_ids): after + # a SWA early-return the next iter's init_next_round_input + # restores fill_ids to origin+output (full length), but the + # row only holds KV up to kv_committed_len — reading beyond + # that yields garbage slot indices. See radix_cache.py for + # the same fix applied to the non-session caches. 
+ assert req.kv_committed_len >= req.cache_protected_len kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(req.fill_ids) + req.req_pool_idx, : req.kv_committed_len ] req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True) return True diff --git a/test/registered/unit/managers/test_hisparse_unit.py b/test/registered/unit/managers/test_hisparse_unit.py index 22996c54e6e5..e1ac4e3ab1d9 100644 --- a/test/registered/unit/managers/test_hisparse_unit.py +++ b/test/registered/unit/managers/test_hisparse_unit.py @@ -52,7 +52,7 @@ def _make_req(rid="test-req-0", origin_input_ids=None, output_ids=None): finished_reason=None, hisparse_staging=False, staging=False, - inflight_middle_chunks=0, + pending_middle_outputs=0, ) req.finished = lambda: req.finished_reason is not None req.set_extend_input_len = lambda extend_input_len: setattr( diff --git a/test/registered/unit/managers/test_prefill_adder.py b/test/registered/unit/managers/test_prefill_adder.py index c5c51ee063c6..e4a870c50ff7 100644 --- a/test/registered/unit/managers/test_prefill_adder.py +++ b/test/registered/unit/managers/test_prefill_adder.py @@ -77,6 +77,11 @@ def create_mock_req(self, rid, priority, max_new_tokens, output_len=0, wait_time req.sampling_params = SimpleNamespace(max_new_tokens=max_new_tokens) req.time_stats = SimpleNamespace(wait_queue_entry_time=wait_time) req.finished.return_value = False + # v2 add_one_req reads these on the reuse-branch gate; MagicMock(spec=Req) + # doesn't surface attributes set only in Req.__init__, so seed them. 
+ req.has_pending_chunk = False + req.is_dllm.return_value = False + req.host_hit_length = 0 return req def create_adder(self, running_batch, **kwargs): @@ -383,9 +388,7 @@ def test_mixed_chunk_prefill_budgets(self): req1.last_node = MagicMock() req1.sampling_params.ignore_eos = False - result1 = adder.add_one_req( - req1, has_chunked_req=False, truncation_align_size=None - ) + result1 = adder.add_one_req(req1, truncation_align_size=None) self.assertEqual(len(adder.can_run_list), 1) self.assertEqual(adder.rem_chunk_tokens, 0) # 56 - 56 @@ -417,9 +420,7 @@ def test_mixed_chunk_prefill_budgets(self): req2.last_node = MagicMock() req2.sampling_params.ignore_eos = False - result2 = adder2.add_one_req( - req2, has_chunked_req=False, truncation_align_size=None - ) + result2 = adder2.add_one_req(req2, truncation_align_size=None) self.assertEqual(len(adder2.can_run_list), 1) self.assertEqual(adder2.rem_chunk_tokens, 3) # 59 - 56 = 3 remaining @@ -434,78 +435,12 @@ def test_mixed_chunk_prefill_budgets(self): req3.last_node = MagicMock() req3.sampling_params.ignore_eos = False - result3 = adder2.add_one_req( - req3, has_chunked_req=False, truncation_align_size=None - ) + result3 = adder2.add_one_req(req3, truncation_align_size=None) self.assertEqual(len(adder2.can_run_list), 2) self.assertEqual(adder2.rem_chunk_tokens, 0) # 3 - 3 = 0 self.assertEqual(result3, AddReqResult.OTHER) - def _build_hybrid_swa_chunked_req( - self, - *, - page_size, - rem_swa, - rem_chunk=2048, - extend_input_len=500, - is_hybrid_swa=True, - full_available=100_000, - ): - self.mock_token_allocator.swa_available_size.return_value = rem_swa - self.mock_token_allocator.full_available_size.return_value = full_available - self.mock_token_allocator.available_size.return_value = full_available - self.mock_tree_cache.sliding_window_size = 128 - adder = self.create_adder( - self.create_running_batch(), - page_size=page_size, - rem_chunk_tokens=rem_chunk, - ) - adder.is_hybrid_swa = is_hybrid_swa - - req = 
self.create_mock_req("chunked", priority=0, max_new_tokens=128) - req.extend_input_len = extend_input_len - req.prefix_indices = [] - req.fill_ids = list(range(extend_input_len)) - req.set_extend_input_len = MagicMock() - return adder, req - - def test_add_chunked_req_hybrid_swa_reserves_page_for_alloc_extend(self): - # alloc_extend needs extend_num_tokens + page_size per request. If the - # scheduler hands out all of rem_swa_tokens, alloc_extend cannot get its - # extra page and OOMs. With the fix, extend_input_len must cap at - # rem_swa_tokens - page_size so the page is reserved. - PAGE_SIZE = 64 - REM_SWA = 100 - adder, req = self._build_hybrid_swa_chunked_req( - page_size=PAGE_SIZE, rem_swa=REM_SWA - ) - - result = adder.add_chunked_req(req) - - self.assertIs(result, req) # truncated → chunked prefill continues - req.set_extend_input_len.assert_called_once() - new_len = req.set_extend_input_len.call_args.args[0] - self.assertLessEqual(new_len + PAGE_SIZE, REM_SWA) - self.assertEqual(new_len, REM_SWA - PAGE_SIZE) - - def test_add_chunked_req_hybrid_swa_defers_when_swa_below_page(self): - # When rem_swa_tokens <= page_size there is no room to serve even the - # reservation, so the chunked req must be deferred (returned unchanged) - # instead of falling back to rem_chunk_tokens and bypassing SWA budget. 
- PAGE_SIZE = 64 - adder, req = self._build_hybrid_swa_chunked_req( - page_size=PAGE_SIZE, rem_swa=PAGE_SIZE - ) - original_len = req.extend_input_len - - result = adder.add_chunked_req(req) - - self.assertIs(result, req) - req.set_extend_input_len.assert_not_called() - self.assertEqual(req.extend_input_len, original_len) - self.assertEqual(len(adder.can_run_list), 0) - def test_swa_budget_for_req(self): cases = [ # (extend, rem_chunk, window, page, expected, label) @@ -526,24 +461,6 @@ def test_swa_budget_for_req(self): ) self.assertEqual(adder._swa_budget_for_req(extend), expected) - def test_add_chunked_req_non_hybrid_no_swa_reservation(self): - # Non-hybrid path: the SWA-pool reservation must NOT apply, otherwise - # the fix would regress non-SWA models. - PAGE_SIZE = 16 - adder, req = self._build_hybrid_swa_chunked_req( - page_size=PAGE_SIZE, - rem_swa=10, - rem_chunk=500, - extend_input_len=200, - is_hybrid_swa=False, - full_available=300, - ) - - result = adder.add_chunked_req(req) - self.assertIsNone(result) - req.set_extend_input_len.assert_called_once_with(200) - self.assertIn(req, adder.can_run_list) - if __name__ == "__main__": unittest.main() diff --git a/test/registered/unit/managers/test_scheduler_chunked_req_gate.py b/test/registered/unit/managers/test_scheduler_chunked_req_gate.py deleted file mode 100644 index 0263170bc839..000000000000 --- a/test/registered/unit/managers/test_scheduler_chunked_req_gate.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Regression tests for the SWA chunked-req stash gate (#24252).""" - -import unittest -from array import array -from types import SimpleNamespace -from unittest.mock import MagicMock - -import torch - -from sglang.test.ci.ci_register import register_cpu_ci -from sglang.test.test_utils import CustomTestCase, maybe_stub_sgl_kernel - -maybe_stub_sgl_kernel() - -from sglang.srt.managers.schedule_batch import Req -from sglang.srt.managers.scheduler import Scheduler -from sglang.srt.mem_cache.chunk_cache import 
ChunkCache - -register_cpu_ci(est_time=6, suite="base-a-test-cpu") - - -def _make_req( - *, - req_pool_idx: int, - fill_ids: list, - prefix_indices: torch.Tensor, - extend_input_len: int, -) -> Req: - req = Req.__new__(Req) - req.rid = "test-req" - req.origin_input_ids = array("q", fill_ids) - req.output_ids = array("q") - req.fill_ids = array("q", fill_ids) - req.prefix_indices = prefix_indices - req.req_pool_idx = req_pool_idx - req.extend_input_len = extend_input_len - req.inflight_middle_chunks = 0 - req.host_hit_length = 0 - req.cache_protected_len = 0 - req.skip_radix_cache_insert = False - req.last_node = None - req.swa_uuid_for_lock = None - req.session = None - req.return_logprob = False - req.logprob_start_len = -1 - req.positional_embed_overrides = None - req.extra_key = None - req.mamba_pool_idx = None - req.sampling_params = SimpleNamespace(max_new_tokens=128, ignore_eos=False) - return req - - -def _make_req_to_token_pool(num_slots: int, max_context: int) -> SimpleNamespace: - # Slot s contains a recognizable fingerprint [s*1000, s*1000+1, ...] - # so we can tell a corrupted prefix_indices from a healthy one by content. 
- pool = SimpleNamespace() - pool.req_to_token = ( - torch.arange(max_context, dtype=torch.int32).unsqueeze(0).repeat(num_slots, 1) - + torch.arange(num_slots, dtype=torch.int32).unsqueeze(1) * 1000 - ) - return pool - - -def _make_chunk_cache(req_to_token_pool) -> ChunkCache: - return ChunkCache( - SimpleNamespace( - req_to_token_pool=req_to_token_pool, - token_to_kv_pool_allocator=None, - page_size=1, - ) - ) - - -def _scheduler_for_get_next_batch(*, tree_cache, chunked_req) -> Scheduler: - s = Scheduler.__new__(Scheduler) - s._abort_on_waiting_timeout = MagicMock() - s._abort_on_running_timeout = MagicMock() - s.dllm_config = None - s.dllm_manager = None - s.enable_hisparse = False - s.enable_fpm = False - s.last_batch = None - s.require_mlp_sync = False - s.spec_algorithm = MagicMock() - s.server_args = MagicMock(speculative_skip_dp_mlp_sync=True) - s.running_batch = MagicMock() - s.running_batch.is_empty.return_value = True - s.running_batch.is_prefill_only = False - s.running_batch.batch_is_full = False - s.running_batch.reqs = [] - s.get_new_batch_prefill = MagicMock(return_value=None) - s.dp_attn_adapter = MagicMock() - s.dp_attn_adapter.maybe_prepare_mlp_sync_batch = MagicMock( - side_effect=lambda batch, **_: batch - ) - s._maybe_prepare_ngram_embedding = MagicMock(side_effect=lambda batch: batch) - s.update_running_batch = MagicMock(side_effect=lambda batch: batch) - s.tree_cache = tree_cache - s.chunked_req = chunked_req - return s - - -class TestStashGatePreservesPrefixIndices(CustomTestCase): - """Consumer side: real ChunkCache.cache_unfinished_req mutates - req.prefix_indices iff stash actually runs, so prefix_indices content - is the bug-detection signal.""" - - POOL_IDX = 4 - INITIAL_PREFIX_LEN = 8 # what was really cached last iter - POST_RESET_FILL_LEN = 32 # length after init_next_round_input - NUM_SLOTS = 8 - MAX_CONTEXT = 64 - - def _build(self, flag: bool): - pool = _make_req_to_token_pool(self.NUM_SLOTS, self.MAX_CONTEXT) - cache = 
_make_chunk_cache(pool) - initial_prefix = pool.req_to_token[self.POOL_IDX, : self.INITIAL_PREFIX_LEN].to( - dtype=torch.int64, copy=True - ) - req = _make_req( - req_pool_idx=self.POOL_IDX, - fill_ids=list(range(self.POST_RESET_FILL_LEN)), - prefix_indices=initial_prefix, - extend_input_len=0, - ) - s = _scheduler_for_get_next_batch(tree_cache=cache, chunked_req=req) - s._chunked_req_scheduled_last_iter = flag - return s, req, initial_prefix, pool - - def test_deferred_chunked_req_keeps_real_prefix_indices(self): - # The bug case: a spurious stash on a deferred chunked_req - # would extend prefix_indices to len(fill_ids). - s, req, initial_prefix, _ = self._build(flag=False) - - Scheduler.get_next_batch_to_run(s) - - self.assertEqual(req.prefix_indices.shape[0], self.INITIAL_PREFIX_LEN) - self.assertTrue(torch.equal(req.prefix_indices, initial_prefix)) - - def test_scheduled_chunked_req_advances_prefix_indices_via_real_stash(self): - # Symmetric guard against over-gating: when the chunked_req was - # actually scheduled, stash must run and advance prefix_indices. - s, req, _, pool = self._build(flag=True) - - Scheduler.get_next_batch_to_run(s) - - expected = pool.req_to_token[self.POOL_IDX, : self.POST_RESET_FILL_LEN].to( - dtype=torch.int64 - ) - self.assertEqual(req.prefix_indices.shape[0], self.POST_RESET_FILL_LEN) - self.assertTrue(torch.equal(req.prefix_indices, expected)) - - def test_no_chunked_req_never_mutates_state_even_with_stale_flag(self): - # Retract path clears chunked_req without resetting the flag; - # the outer `if chunked_req is not None` guard must hold. - pool = _make_req_to_token_pool(self.NUM_SLOTS, self.MAX_CONTEXT) - cache = _make_chunk_cache(pool) - s = _scheduler_for_get_next_batch(tree_cache=cache, chunked_req=None) - s._chunked_req_scheduled_last_iter = True - - Scheduler.get_next_batch_to_run(s) - self.assertIsNone(s.chunked_req) - - -if __name__ == "__main__": - unittest.main()