
[Eagle] Refactor eagle speculative decoding #3986

Merged

zhyncs merged 8 commits into main from ying-eagle, Mar 5, 2025

Conversation

@Ying1123 (Contributor) commented Mar 2, 2025

Prefix caching and chunked prefill will be compatible with eagle speculative decoding after this PR.

Co-authored-by: SangBin Cho rkooo567@gmail.com
Co-authored-by: Sehoon Kim kssteven418@gmail.com
Co-authored-by: Lianmin Zheng lianminzheng@gmail.com

@Ying1123 Ying1123 marked this pull request as draft March 2, 2025 02:56
@Ying1123 Ying1123 force-pushed the ying-eagle branch 5 times, most recently from e80b2de to 9c28d33, March 2, 2025 10:20
@Ying1123 Ying1123 marked this pull request as ready for review March 3, 2025 01:50
@Ying1123 Ying1123 requested a review from HaiShaw as a code owner March 3, 2025 01:50
@zhyncs (Collaborator) commented Mar 3, 2025

@Ying1123 Could you help resolve the conflicts? Thanks!

@Ying1123 Ying1123 force-pushed the ying-eagle branch 3 times, most recently from 5e9c58a to 172c25f, March 3, 2025 05:46
@xiezhq-hermann (Collaborator) commented on the diff:
I think style-wise this is a bit confusing, if we define the allocator to be in charge of the agnostic memory operations and define another memory pool class for the underlying layouts, we should be using allocators consistently in scheduler and only use memory pool at lower level codes.
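The layering this comment argues for can be sketched roughly as follows. This is an illustrative sketch only — the class and method names are hypothetical stand-ins, not the actual sglang API: the allocator owns layout-agnostic free-slot bookkeeping and is what the scheduler talks to, while the pool only knows the underlying KV-cache layout.

```python
class TokenToKVPool:
    """Lower level: owns the KV buffers; knows nothing about who holds slots."""

    def __init__(self, size: int):
        self.size = size
        self.kv_data = [None] * size  # stand-in for the real KV tensors


class TokenToKVPoolAllocator:
    """Higher level: hands out and reclaims slot indices into the pool."""

    def __init__(self, pool: TokenToKVPool):
        self.pool = pool
        self.free_slots = list(range(pool.size))

    def alloc(self, n: int):
        # Return n slot indices, or None if not enough free memory.
        if n > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, indices):
        # Return slot indices to the free list.
        self.free_slots.extend(indices)


# The scheduler would interact with the allocator only; only lower-level
# attention/KV code would touch the pool's layout directly.
allocator = TokenToKVPoolAllocator(TokenToKVPool(8))
locs = allocator.alloc(3)
allocator.free(locs)
```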

@merrymercy (Contributor) commented Mar 5, 2025:
Re @xiezhq-hermann: Let us merge this first to reduce the code divergence. Feel free to refactor it later with a better design.

A contributor replied:
how about token_to_kv_indices_pool 😂

@mpjlu (Contributor) commented Mar 3, 2025

commit a574770: there is an illegal memory access bug when running DeepSeek:

```
  File "/data/peng/sglang/python/sglang/srt/managers/scheduler.py", line 1218, in run_batch
    ) = self.draft_worker.forward_batch_speculative_generation(batch)
  File "/data/peng/sglang/python/sglang/srt/speculative/eagle_worker.py", line 189, in forward_batch_speculative_generation
    spec_info, to_free_cache_loc = self.draft(batch)
  File "/data/peng/sglang/python/sglang/srt/speculative/eagle_worker.py", line 244, in draft
    assign_draft_cache_locs[(num_seqs,)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
```

@mpjlu Thanks for reporting. Could you provide the reproducible command? @ispobock @zhyncs Could you also help take a look?

The following command reproduces the error:
```shell
python3 -m sglang.launch_server \
  --model-path $model_path \
  --tp $tp_size \
  --dist-init-addr 29.224.56.106:5000 \
  --nnodes 2 \
  --node-rank 0 \
  --trust-remote-code \
  --mem-fraction-static 0.6 \
  --max-running-requests 64 \
  --speculative-draft-model-path $draft_path \
  --speculative-algorithm NEXTN \
  --speculative-num-steps 2 \
  --speculative-eagle-topk 2 \
  --speculative-num-draft-tokens 4 \
  --disable-cuda-graph
```

test.py

```python
import time

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Chat completion
start = time.time()
for i in range(100):
    response = client.chat.completions.create(
        model="default",
        messages=[
            # Prompt (Chinese): "Please write a 1000-word essay on the theme of integrity."
            {"role": "user", "content": "请以诚信为主题写一篇1000字作文?"},
        ],
        temperature=0.6,
        max_tokens=1000,
        extra_body={"top_p": 0.6, "top_k": 50},
    )
    print(response)
print("dur=", time.time() - start)
```

@zhyncs zhyncs mentioned this pull request Mar 3, 2025
@ispobock (Collaborator) commented Mar 4, 2025

I verified e19e733 on 8*H200 for DeepSeek-V3 model with nextn enabled and it works fine.

@mpjlu I cannot reproduce your error in my environment. I am not sure if it's an error for multi-node setting since I tried the same args as your command but on one node.

@mpjlu (Contributor) commented Mar 4, 2025

> I verified e19e733 on 8*H200 for DeepSeek-V3 model with nextn enabled and it works fine.
>
> @mpjlu I cannot reproduce your error in my environment. I am not sure if it's an error for multi-node setting since I tried the same args as your command but on one node.

Thanks very much.
We can also run with 8*H20, but we cannot run with 16*H20 with TP 16.

```python
self.disable_radix_cache = True
self.chunked_prefill_size = -1
if self.max_running_requests is None:
    self.max_running_requests = 32
```
A collaborator commented on the snippet above:
This setting may hurt throughput, especially for a throughput-oriented model like DeepSeek. I tried request rate 16 on the ShareGPT dataset; with this limit, TTFT is higher and throughput is lower. We may need a solution that enables larger batch sizes.

@ispobock (Collaborator) commented Mar 4, 2025

> > I verified e19e733 on 8*H200 for DeepSeek-V3 model with nextn enabled and it works fine.
> >
> > @mpjlu I cannot reproduce your error in my environment. I am not sure if it's an error for multi-node setting since I tried the same args as your command but on one node.
>
> Thanks very much. We can also run with 8*H20, but we cannot run with 16*H20 with TP 16.

We will verify it for TP 16.

@ispobock (Collaborator) commented Mar 5, 2025

@mpjlu Could you help test the latest commit again?

@Ying1123 Ying1123 changed the title Refactor eagle speculative decoding [Eagle] Support prefix caching and chunked prefill for eagle speculative decoding Mar 5, 2025
@Ying1123 Ying1123 changed the title [Eagle] Support prefix caching and chunked prefill for eagle speculative decoding [Eagle] Refactor eagle speculative decoding Mar 5, 2025

@zhyncs zhyncs merged commit d3d4d76 into main Mar 5, 2025
33 of 36 checks passed
@zhyncs zhyncs deleted the ying-eagle branch March 5, 2025 16:06

```diff
-class BaseTokenToKVPool:
+class TokenToKVPoolAllocator:
     """A memory pool that maps a token location to its kv cache data."""
```
A contributor commented on the docstring:
How about "A memory pool that stores free slots in the kv cache data"?
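Under the suggested wording, the renamed class is essentially a free-list over slot indices in the KV cache. A minimal illustrative sketch of that semantics — hypothetical names, not the actual sglang implementation:

```python
class FreeSlotPool:
    """A memory pool that stores free slots in the kv cache data."""

    def __init__(self, num_slots: int):
        # Every slot index starts out free.
        self._free = list(range(num_slots))

    def available(self) -> int:
        return len(self._free)

    def take(self, n: int) -> list:
        # Hand out up to n free slot indices.
        n = min(n, len(self._free))
        taken, self._free = self._free[:n], self._free[n:]
        return taken

    def give_back(self, slots) -> None:
        # Reclaim slot indices once their KV entries are no longer needed.
        self._free.extend(slots)
```

The point of the docstring debate is that this class never touches KV data itself; it only tracks which slot indices are free.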

7 participants