Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def process_prebuilt(
hidden_states=hidden_states,
verified_id=self.output_ids,
new_seq_lens=self.seq_lens,
allocate_lens=self.seq_lens,
)
spec_info.prepare_for_extend(self)
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,7 +1760,7 @@ def filter_batch(

def merge_batch(self, other: "ScheduleBatch"):
# NOTE: in v2 eagle mode, we do not need to wait for verify here because
# 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
# 1) current batch is always prefill, whose seq_lens is not a future
# 2) other batch is always decode, which is finished in previous step

# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
Expand Down
2 changes: 0 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,8 +2074,6 @@ def run_batch(
# batch.spec_info = EagleDraftInput(
# future_indices=future_indices,
# verify_done=batch_result.next_draft_input.verify_done,
# # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
# allocate_lens=batch_result.next_draft_input.allocate_lens,
# )

# The future value, usually for next batch preparation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,16 @@ def _resolve_spec_overlap_token_ids(
"""Resolve the padding next token ids for speculative decoding with overlap."""
assert result.next_token_ids.is_cpu
assert result.accept_lens.is_cpu
assert result.allocate_lens.is_cpu

next_token_ids = result.next_token_ids.tolist()
accept_lens = result.accept_lens.tolist()
result.num_accepted_tokens = sum(accept_lens) - len(batch.reqs)

predict_tokens = []
stride = self.draft_worker.speculative_num_draft_tokens

for i, req in enumerate(batch.reqs):
req.kv_committed_len += accept_lens[i]
predict_tokens.append(
next_token_ids[i * stride : i * stride + accept_lens[i]]
)
Expand Down Expand Up @@ -300,8 +301,6 @@ def process_batch_result_decode(
next_token_logprobs = logits_output.next_token_logprobs.tolist()
elif batch.is_v2_eagle:
next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
allocate_lens_list = result.allocate_lens.tolist()
accept_lens_list = result.accept_lens.tolist()

self.num_generated_tokens += len(batch.reqs)
if not batch.spec_algorithm.is_none():
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/srt/managers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class GenerationBatchResult:
# FIXME(lsyin): maybe move to a better place?
# sync path: forward stream -> output processor
accept_lens: Optional[torch.Tensor] = None
allocate_lens: Optional[torch.Tensor] = None

# relay path: forward stream -> next step forward
next_draft_input: Optional[EagleDraftInput] = None
Expand Down Expand Up @@ -67,9 +66,6 @@ def copy_to_cpu(self, return_logprob: bool = False):
if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)

if self.allocate_lens is not None:
self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)

self.copy_done.record()

@classmethod
Expand Down
6 changes: 0 additions & 6 deletions python/sglang/srt/speculative/eagle_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):

# Inputs for V2 overlap worker
future_indices: Optional[FutureIndices] = None
allocate_lens: Optional[torch.Tensor] = None
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None

Expand Down Expand Up @@ -665,7 +664,6 @@ def create_idle_input(
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
capture_hidden_mode=capture_hidden_mode,
allocate_lens=torch.empty((0,), device=device, dtype=torch.int32),
new_seq_lens=torch.empty((0,), device=device, dtype=torch.int32),
accept_length=torch.empty((0,), device=device, dtype=torch.int32),
accept_length_cpu=[],
Expand Down Expand Up @@ -738,7 +736,6 @@ def generate_attn_arg_prefill(
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
if self.future_indices is not None:
self.future_indices.indices = self.future_indices.indices[new_indices]
self.allocate_lens = self.allocate_lens[new_indices]
return

if has_been_filtered:
Expand Down Expand Up @@ -767,9 +764,6 @@ def merge_batch(self, spec_info: "EagleDraftInput"):
[self.future_indices.indices, spec_info.future_indices.indices]
)
)
self.allocate_lens = torch.cat(
[self.allocate_lens, spec_info.allocate_lens]
)
return

if self.hidden_states is None:
Expand Down
46 changes: 24 additions & 22 deletions python/sglang/srt/speculative/eagle_info_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,55 +84,57 @@ def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):

bs = batch.batch_size()

# TODO(lsyin): implement over-allocation
# Now seq_lens and allocate_lens are correct
# Now seq_lens is correct
batch.maybe_wait_verify_done()

page_size = batch.token_to_kv_pool_allocator.page_size
cur_kv_lens_cpu = []
nxt_kv_lens_cpu = []
num_needed_tokens = 0
for r in batch.reqs:
# Over-allocation happens here
x = r.kv_committed_len + 2 * self.ALLOC_LEN_PER_DECODE - r.kv_allocated_len
cur_kv_lens_cpu.append(r.kv_allocated_len)
nxt_kv_lens_cpu.append(r.kv_allocated_len + x)
num_needed_tokens += x
r.kv_allocated_len += x

cur_kv_lens_cpu = torch.tensor(cur_kv_lens_cpu, dtype=torch.int32, device="cpu")
nxt_kv_lens_cpu = torch.tensor(nxt_kv_lens_cpu, dtype=torch.int32, device="cpu")

if page_size == 1:
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
else:
cur_kv_lens = cur_kv_lens_cpu.to(device=batch.device)
nxt_kv_lens = nxt_kv_lens_cpu.to(device=batch.device)
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
self.allocate_lens,
cur_kv_lens,
)
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
new_allocate_lens_cpu = new_allocate_lens.cpu()
allocate_lens_cpu = self.allocate_lens.cpu()
extend_num_tokens = sum(new_allocate_lens_cpu - allocate_lens_cpu).item()
out_cache_loc = alloc_paged_token_slots_extend(
batch.tree_cache,
self.allocate_lens,
allocate_lens_cpu,
new_allocate_lens,
new_allocate_lens_cpu,
cur_kv_lens,
cur_kv_lens_cpu,
nxt_kv_lens,
nxt_kv_lens_cpu,
last_loc,
extend_num_tokens,
num_needed_tokens,
)

assign_req_to_token_pool_func(
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
cur_kv_lens_cpu.to(device=batch.device),
nxt_kv_lens_cpu.to(device=batch.device),
out_cache_loc,
bs,
)

self.allocate_lens = new_allocate_lens

# FIXME(lsyin): make this sync optional
batch.seq_lens_cpu = batch.seq_lens.cpu()
batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()

for i, req in enumerate(batch.reqs):
req.kv_committed_len = batch.seq_lens_cpu[i].item()
req.kv_allocated_len = req.kv_committed_len + self.ALLOC_LEN_PER_DECODE

def prepare_for_v2_draft(
self: EagleDraftInput,
req_to_token_pool: ReqToTokenPool,
Expand Down
12 changes: 2 additions & 10 deletions python/sglang/srt/speculative/eagle_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,6 @@ def _draft_extend_for_prefill(
hidden_states=target_hidden_states,
verified_id=next_token_ids,
new_seq_lens=batch.seq_lens,
allocate_lens=batch.seq_lens,
# draft mode is same with decode mode, only 1 num token per batch
num_tokens_per_batch=1,
num_tokens_for_logprob_per_batch=1,
Expand Down Expand Up @@ -620,19 +619,14 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
draft_input: EagleDraftInput = model_worker_batch.spec_info
verify_input: EagleVerifyInput = self.draft_worker.draft(model_worker_batch)
assert verify_input.is_verify_input()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
batch_output = self.verify(model_worker_batch)
self.draft_worker._draft_extend_for_decode(model_worker_batch, batch_output)
return batch_output

def verify(
self,
batch: ModelWorkerBatch,
cur_allocate_lens: torch.Tensor,
):
def verify(self, batch: ModelWorkerBatch):
# batch.seq_lens is allocated in another stream, so we need
# record_stream() to keep PyTorch from freeing and reusing its GPU
# memory while forward_stream is still running.
Expand Down Expand Up @@ -710,7 +704,6 @@ def verify(
next_draft_input = EagleDraftInput(
verified_id=verified_id,
new_seq_lens=new_seq_lens,
allocate_lens=cur_allocate_lens,
verify_done=verify_done,
)

Expand All @@ -720,7 +713,6 @@ def verify(
can_run_cuda_graph=can_run_cuda_graph,
next_draft_input=next_draft_input,
accept_lens=accept_length,
allocate_lens=cur_allocate_lens,
)

def move_accepted_tokens_to_target_kvcache(
Expand Down
Loading