Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,8 @@ def __init__(
self.host_hit_length = 0
# The node to lock until for swa radix tree lock ref
self.swa_uuid_for_lock: Optional[int] = None
# The prefix length of the last prefix matching
self.last_matched_prefix_len: int = 0

# Whether or not if it is chunked. It increments whenever
# it is chunked, and decrement whenever chunked request is
Expand Down Expand Up @@ -700,6 +702,7 @@ def init_next_round_input(
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
),
)
self.last_matched_prefix_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

def adjust_max_prefix_ids(self):
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ def init_memory_pool_and_cache(self):
disable=server_args.disable_radix_cache,
enable_kv_cache_events=self.enable_kv_cache_events,
eviction_policy=server_args.radix_eviction_policy,
is_eagle=self.spec_algorithm.is_eagle(),
)

if (
Expand Down
97 changes: 79 additions & 18 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1):
return (key.extra_key, plain_key)


def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
# EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
if len(tokens) < 2:
return []
if isinstance(tokens[0], tuple):
return tokens
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]


class RadixCache(BasePrefixCache):
def __init__(
self,
Expand All @@ -168,13 +178,15 @@ def __init__(
disable: bool = False,
enable_kv_cache_events: bool = False,
eviction_policy: str = "lru",
is_eagle: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue = []
self.is_eagle = is_eagle

if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
Expand All @@ -188,6 +200,11 @@ def __init__(
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = partial(get_child_key, page_size=page_size)

if is_eagle:
self.key_convert_fn = _convert_to_bigram_key
else:
self.key_convert_fn = lambda key: key

if eviction_policy.lower() == "lru":
self.eviction_strategy: EvictionStrategy = LRUStrategy()
elif eviction_policy.lower() == "lfu":
Expand Down Expand Up @@ -248,6 +265,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
to expose a precise boundary; this structural refinement improves
subsequent match efficiency and does not duplicate data.
"""
key.token_ids = self.key_convert_fn(key.token_ids)

if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
Expand Down Expand Up @@ -278,8 +297,15 @@ def insert(self, key: RadixKey, value=None, chunked=False):
if self.disable:
return 0

key.token_ids = self.key_convert_fn(key.token_ids)

if value is None:
value = torch.tensor(key.token_ids, dtype=torch.int64)

if self.is_eagle:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hiradix (and other trees like swa) override the insert function, would that be a problem since eagle worker shared the same tree?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in current design, we need to adapt this change to other trees like swa and hiradix if they override these functions. This PR just makes the main radix tree ready. HiCache and swa need extra work and test to make them ready.

# Make sure the value len equal to the EAGLE bigram key len
value = value[: len(key)]

return self._insert_helper(self.root_node, key, value)

def cache_finished_req(self, req: Req):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while hiradix does not, swa tree override this implementation as well

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The swa cache inherits from BaseRadixCache, so it seems all the changes should be implemented again on it. HiCache is from RadixCache, we just need to do some adaptation on it with less override. But for HiCache, the main thing I'm concerning is that the chunked prefill size is a little changed. If the chunked prefill size is 64, actually only 63 bigram keys are inserted to the tree. Maybe it's not efficient for cache offloading with block.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is what we are doing is primarily to resolve conflict with eagle workers since it shares the same radix tree but has its own pool, but not to have hicache support for eagle workers, i.e., eagle workers to fetch kv caches from host memory, which seems unnecessary and potentially complicated. Is it correct?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the kv cache for eagle worker is unnecessary to store into host memory since it's only one layer. If we use HiCache only for target model, can we still share the kv indices between target and draft pool?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I think it should be fine just wanted to confirm that we are aligned on this

Copy link
Copy Markdown
Collaborator Author

@ispobock ispobock Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand All @@ -293,28 +319,39 @@ def cache_finished_req(self, req: Req):
return

token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]

if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)

old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
old_prefix_len -= 1

# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
page_aligned_kv_indices,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])

# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
Expand All @@ -326,49 +363,73 @@ def cache_unfinished_req(self, req: Req, chunked=False):
return

token_ids = req.fill_ids
all_token_len = len(token_ids)
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]

if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
page_aligned_token_ids = token_ids[:page_aligned_len]

# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
page_aligned_token_ids = token_ids[:page_aligned_token_len]

old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
old_prefix_len -= 1

# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(
RadixKey(page_aligned_token_ids, req.extra_key),
page_aligned_kv_indices,
chunked=chunked,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])

# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
new_indices[old_prefix_len:],
)

# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
req.last_matched_prefix_len = len(new_indices)

self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)

# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
if self.is_eagle:
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
req.prefix_indices = torch.cat(
[new_indices, kv_indices[actual_kv_len:]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node

def pretty_print(self):
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@
# EAGLE
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
)
Expand Down
91 changes: 84 additions & 7 deletions test/srt/test_eagle_infer_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
Expand All @@ -35,6 +37,11 @@ class TestEAGLEEngine(CustomTestCase):
}
NUM_CONFIGS = 2

THRESHOLDS = {
"batch_avg_accept_len": 1.9,
"accept_len": 3.6,
}

def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
Expand Down Expand Up @@ -63,6 +70,7 @@ def test_correctness(self):
self._test_eos_token(engine)
self._test_acc_length(engine)
finally:
engine.flush_cache() # check engine alive
engine.shutdown()
print("=" * 100)

Expand Down Expand Up @@ -92,7 +100,9 @@ def _test_batch_generation(self, engine):
"avg_spec_accept_length"
]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.9)
self.assertGreater(
avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"]
)

def _test_eos_token(self, engine):
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
Expand Down Expand Up @@ -131,10 +141,7 @@ def _test_acc_length(self, engine):
)
print(f"{acc_length=:.4f}, {speed=}")

if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
self.assertGreater(acc_length, 3.6)
else:
self.assertGreater(acc_length, 2.5)
self.assertGreater(acc_length, self.THRESHOLDS["accept_len"])


class TestEAGLEEngineTokenMap(TestEAGLEEngine):
Expand All @@ -151,12 +158,16 @@ class TestEAGLEEngineTokenMap(TestEAGLEEngine):
"dtype": "float16",
}
NUM_CONFIGS = 1
THRESHOLDS = {
"batch_avg_accept_len": 1.9,
"accept_len": 2.5,
}


class TestEAGLE3Engine(TestEAGLEEngine):
BASE_CONFIG = {
"model_path": "meta-llama/Llama-3.1-8B-Instruct",
"speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
"speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
"speculative_algorithm": "EAGLE3",
"speculative_num_steps": 5,
"speculative_eagle_topk": 16,
Expand All @@ -166,6 +177,72 @@ class TestEAGLE3Engine(TestEAGLEEngine):
"dtype": "float16",
}
NUM_CONFIGS = 1
THRESHOLDS = {
"batch_avg_accept_len": 1.75,
"accept_len": 3.1,
}


class TestEAGLERadixCache(CustomTestCase):
    """Exercise the radix cache together with EAGLE3 speculative decoding,
    including chunked prefill and page_size > 1 variants."""

    # Small speculative settings keep the run light while still going
    # through the EAGLE bigram radix-cache code path.
    BASE_CONFIG = {
        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
        "speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
        "speculative_algorithm": "EAGLE3",
        "speculative_num_steps": 2,
        "speculative_eagle_topk": 1,
        "speculative_num_draft_tokens": 3,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
        "dtype": "float16",
    }

    def test_correctness(self):
        """Launch one engine per config variant and check the speculative
        acceptance length after a cache-hitting second request."""
        configs = [
            # Basic config
            self.BASE_CONFIG,
            # Chunked prefill
            {**self.BASE_CONFIG, "chunked_prefill_size": 64},
            # Chunked prefill & Page Size > 1
            {**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4},
        ]

        for i, config in enumerate(configs):
            with self.subTest(i=i):
                print(f"{config=}")
                engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
                try:
                    self._test_acc_length(engine)
                finally:
                    # Always release the engine so the next subTest starts clean.
                    engine.shutdown()
                print("=" * 100)

    def _test_acc_length(self, engine):
        """Warm the radix cache with one prompt, then assert the average
        acceptance length on a prompt that shares a prefix with the cache."""
        warmup_prompt = [
            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
        ]
        sampling_params = {"temperature": 0, "max_new_tokens": 512}
        output = engine.generate(warmup_prompt, sampling_params)
        test_prompt = [
            "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        ]
        output = engine.generate(test_prompt, sampling_params)
        output = output[0]

        # Acceptance length = completion tokens per verify step; falls back
        # to 1.0 when speculative-decoding metadata is absent.
        if "spec_verify_ct" in output["meta_info"]:
            acc_length = (
                output["meta_info"]["completion_tokens"]
                / output["meta_info"]["spec_verify_ct"]
            )
        else:
            acc_length = 1.0

        # Tokens per second, reported for manual inspection only.
        speed = (
            output["meta_info"]["completion_tokens"]
            / output["meta_info"]["e2e_latency"]
        )
        print(f"{acc_length=:.4f}, {speed=}")

        self.assertGreater(acc_length, 2.5)


@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
Expand Down
Loading
Loading