diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index 80d29e9bbab9..efa979460f63 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -170,7 +170,6 @@ def process_prebuilt( hidden_states=hidden_states, verified_id=self.output_ids, new_seq_lens=self.seq_lens, - allocate_lens=self.seq_lens, ) spec_info.prepare_for_extend(self) spec_info.capture_hidden_mode = CaptureHiddenMode.LAST diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 64c3eb1de01e..6a7f44a310e4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1760,7 +1760,7 @@ def filter_batch( def merge_batch(self, other: "ScheduleBatch"): # NOTE: in v2 eagle mode, we do not need wait verify here because - # 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future + # 1) current batch is always prefill, whose seq_lens is not a future # 2) other batch is always decode, which is finished in previous step # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 899af5184915..cc5d2fed645b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2074,8 +2074,6 @@ def run_batch( # batch.spec_info = EagleDraftInput( # future_indices=future_indices, # verify_done=batch_result.next_draft_input.verify_done, - # # FIXME(lsyin): remove the allocate_lens in EagleDraftInput - # allocate_lens=batch_result.next_draft_input.allocate_lens, # ) # The future value, usually for next batch preparation diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 515ec8195e42..e3d8f9668f38 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -263,7 +263,6 @@ def _resolve_spec_overlap_token_ids( """Resolve the padding next token ids for speculative decoding with overlap.""" assert result.next_token_ids.is_cpu assert result.accept_lens.is_cpu - assert result.allocate_lens.is_cpu next_token_ids = result.next_token_ids.tolist() accept_lens = result.accept_lens.tolist() @@ -271,7 +270,9 @@ def _resolve_spec_overlap_token_ids( predict_tokens = [] stride = self.draft_worker.speculative_num_draft_tokens + for i, req in enumerate(batch.reqs): + req.kv_committed_len += accept_lens[i] predict_tokens.append( next_token_ids[i * stride : i * stride + accept_lens[i]] ) @@ -300,8 +301,6 @@ def process_batch_result_decode( next_token_logprobs = logits_output.next_token_logprobs.tolist() elif batch.is_v2_eagle: next_token_ids = self._resolve_spec_overlap_token_ids(result, batch) - allocate_lens_list = result.allocate_lens.tolist() - accept_lens_list = result.accept_lens.tolist() self.num_generated_tokens += len(batch.reqs) if not batch.spec_algorithm.is_none(): diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index 6e753f1659c4..640994e70450 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -39,7 +39,6 @@ class GenerationBatchResult: # FIXME(lsyin): maybe move to a better place? # sync path: forward stream -> output processor accept_lens: Optional[torch.Tensor] = None - allocate_lens: Optional[torch.Tensor] = None # relay path: forward stream -> next step forward next_draft_input: Optional[EagleDraftInput] = None @@ -67,9 +66,6 @@ def copy_to_cpu(self, return_logprob: bool = False): if self.accept_lens is not None: self.accept_lens = self.accept_lens.to("cpu", non_blocking=True) - if self.allocate_lens is not None: - self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True) - self.copy_done.record() @classmethod diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 084ccdb1220f..35f30a016745 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -624,7 +624,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): # Inputs for V2 overlap worker future_indices: Optional[FutureIndices] = None - allocate_lens: Optional[torch.Tensor] = None new_seq_lens: Optional[torch.Tensor] = None verify_done: Optional[torch.cuda.Event] = None @@ -665,7 +664,6 @@ def create_idle_input( topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), capture_hidden_mode=capture_hidden_mode, - allocate_lens=torch.empty((0,), device=device, dtype=torch.int32), new_seq_lens=torch.empty((0,), device=device, dtype=torch.int32), accept_length=torch.empty((0,), device=device, dtype=torch.int32), accept_length_cpu=[], @@ -738,7 +736,6 @@ def generate_attn_arg_prefill( def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): if self.future_indices is not None: self.future_indices.indices = self.future_indices.indices[new_indices] - self.allocate_lens = self.allocate_lens[new_indices] return if has_been_filtered: @@ -767,9 +764,6 @@ def merge_batch(self, spec_info: "EagleDraftInput"): [self.future_indices.indices, spec_info.future_indices.indices] ) ) - self.allocate_lens = torch.cat( - [self.allocate_lens, spec_info.allocate_lens] - ) return if self.hidden_states is None: diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index c6a763942561..7607af9acbee 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -84,55 +84,57 @@ def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): bs = batch.batch_size() - # TODO(lsyin): implement over-allocation - # Now seq_lens and allocate_lens are correct + # Now seq_lens is correct batch.maybe_wait_verify_done() page_size = batch.token_to_kv_pool_allocator.page_size + cur_kv_lens_cpu = [] + nxt_kv_lens_cpu = [] + num_needed_tokens = 0 + for r in batch.reqs: + # Over-allocation happens here + x = r.kv_committed_len + 2 * self.ALLOC_LEN_PER_DECODE - r.kv_allocated_len + cur_kv_lens_cpu.append(r.kv_allocated_len) + nxt_kv_lens_cpu.append(r.kv_allocated_len + x) + num_needed_tokens += x + r.kv_allocated_len += x + + cur_kv_lens_cpu = torch.tensor(cur_kv_lens_cpu, dtype=torch.int32, device="cpu") + nxt_kv_lens_cpu = torch.tensor(nxt_kv_lens_cpu, dtype=torch.int32, device="cpu") if page_size == 1: - new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE - num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item() out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens) else: + cur_kv_lens = cur_kv_lens_cpu.to(device=batch.device) + nxt_kv_lens = nxt_kv_lens_cpu.to(device=batch.device) last_loc = get_last_loc( batch.req_to_token_pool.req_to_token, batch.req_pool_indices, - self.allocate_lens, + cur_kv_lens, ) - new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE - new_allocate_lens_cpu = new_allocate_lens.cpu() - allocate_lens_cpu = self.allocate_lens.cpu() - extend_num_tokens = sum(new_allocate_lens_cpu - allocate_lens_cpu).item() out_cache_loc = alloc_paged_token_slots_extend( batch.tree_cache, - self.allocate_lens, - allocate_lens_cpu, - new_allocate_lens, - new_allocate_lens_cpu, + cur_kv_lens, + cur_kv_lens_cpu, + nxt_kv_lens, + nxt_kv_lens_cpu, last_loc, - extend_num_tokens, + num_needed_tokens, ) assign_req_to_token_pool_func( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, - self.allocate_lens, - new_allocate_lens, + cur_kv_lens_cpu.to(device=batch.device), + nxt_kv_lens_cpu.to(device=batch.device), out_cache_loc, bs, ) - self.allocate_lens = new_allocate_lens - # FIXME(lsyin): make this sync optional batch.seq_lens_cpu = batch.seq_lens.cpu() batch.seq_lens_sum = batch.seq_lens_cpu.sum().item() - for i, req in enumerate(batch.reqs): - req.kv_committed_len = batch.seq_lens_cpu[i].item() - req.kv_allocated_len = req.kv_committed_len + self.ALLOC_LEN_PER_DECODE - def prepare_for_v2_draft( self: EagleDraftInput, req_to_token_pool: ReqToTokenPool, diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 7ee81f6f06b8..5ae77b990480 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -447,7 +447,6 @@ def _draft_extend_for_prefill( hidden_states=target_hidden_states, verified_id=next_token_ids, new_seq_lens=batch.seq_lens, - allocate_lens=batch.seq_lens, # draft mode is same with decode mode, only 1 num token per batch num_tokens_per_batch=1, num_tokens_for_logprob_per_batch=1, @@ -620,19 +619,14 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) - draft_input: EagleDraftInput = model_worker_batch.spec_info verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch) assert verify_input.is_verify_input() model_worker_batch.spec_info = verify_input - batch_output = self.verify(model_worker_batch, draft_input.allocate_lens) + batch_output = self.verify(model_worker_batch) self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output) return batch_output - def verify( - self, - batch: ModelWorkerBatch, - cur_allocate_lens: torch.Tensor, - ): + def verify(self, batch: ModelWorkerBatch): # Since batch.seq_lens is allocated in another stream, we need # record_stream() to prevent pytorch gc and reuse the gpu memory # while forward_stream is still running. @@ -710,7 +704,6 @@ def verify( next_draft_input = EagleDraftInput( verified_id=verified_id, new_seq_lens=new_seq_lens, - allocate_lens=cur_allocate_lens, verify_done=verify_done, ) @@ -720,7 +713,6 @@ def verify( can_run_cuda_graph=can_run_cuda_graph, next_draft_input=next_draft_input, accept_lens=accept_length, - allocate_lens=cur_allocate_lens, ) def move_accepted_tokens_to_target_kvcache(