diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2efd0de92a7b..6d8126860f01 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -546,6 +546,8 @@ def __init__( self.host_hit_length = 0 # The node to lock until for swa radix tree lock ref self.swa_uuid_for_lock: Optional[int] = None + # The prefix length of the last prefix matching + self.last_matched_prefix_len: int = 0 # Whether or not if it is chunked. It increments whenever # it is chunked, and decrement whenever chunked request is @@ -700,6 +702,7 @@ def init_next_round_input( token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key ), ) + self.last_matched_prefix_len = len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 893a0b0a1fdf..d62c7f01c0bf 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -756,6 +756,7 @@ def init_memory_pool_and_cache(self): disable=server_args.disable_radix_cache, enable_kv_cache_events=self.enable_kv_cache_events, eviction_policy=server_args.radix_eviction_policy, + is_eagle=self.spec_algorithm.is_eagle(), ) if ( diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index abb9445f8fdc..2f818770a0a7 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -23,7 +23,7 @@ import time from collections import defaultdict from functools import partial -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union import torch @@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1): return (key.extra_key, plain_key) +def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]: + # EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target + # [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)] + if len(tokens) < 2: + return [] + if isinstance(tokens[0], tuple): + return tokens + return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)] + + class RadixCache(BasePrefixCache): def __init__( self, @@ -168,6 +178,7 @@ def __init__( disable: bool = False, enable_kv_cache_events: bool = False, eviction_policy: str = "lru", + is_eagle: bool = False, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator @@ -175,6 +186,7 @@ def __init__( self.disable = disable self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue = [] + self.is_eagle = is_eagle if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device @@ -188,6 +200,11 @@ def __init__( self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.get_child_key_fn = partial(get_child_key, page_size=page_size) + if is_eagle: + self.key_convert_fn = _convert_to_bigram_key + else: + self.key_convert_fn = lambda key: key + if eviction_policy.lower() == "lru": self.eviction_strategy: EvictionStrategy = LRUStrategy() elif eviction_policy.lower() == "lfu": @@ -248,6 +265,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: to expose a precise boundary; this structural refinement improves subsequent match efficiency and does not duplicate data. """ + key.token_ids = self.key_convert_fn(key.token_ids) + if self.disable or len(key) == 0: return MatchResult( device_indices=torch.empty( @@ -278,8 +297,15 @@ def insert(self, key: RadixKey, value=None, chunked=False): if self.disable: return 0 + key.token_ids = self.key_convert_fn(key.token_ids) + if value is None: value = torch.tensor(key.token_ids, dtype=torch.int64) + + if self.is_eagle: + # Make sure the value len equal to the EAGLE bigram key len + value = value[: len(key)] + return self._insert_helper(self.root_node, key, value) def cache_finished_req(self, req: Req): @@ -293,28 +319,39 @@ def cache_finished_req(self, req: Req): return token_ids = (req.origin_input_ids + req.output_ids)[:-1] + all_token_len = len(token_ids) + actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :all_token_len ] if self.page_size != 1: - page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_len = actual_kv_len // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].to( dtype=torch.int64, copy=True ) self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) else: - page_aligned_len = len(kv_indices) + page_aligned_len = actual_kv_len page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) + if self.is_eagle: + self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) + + page_aligned_token_len = ( + page_aligned_len + 1 if self.is_eagle else page_aligned_len + ) + + old_prefix_len = len(req.prefix_indices) + if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: + # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) + old_prefix_len -= 1 # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( - RadixKey(token_ids[:page_aligned_len], req.extra_key), + RadixKey(token_ids[:page_aligned_token_len], req.extra_key), page_aligned_kv_indices, ) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) + self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len]) # Remove req slot release the cache lock self.req_to_token_pool.free(req.req_pool_idx) @@ -326,19 +363,32 @@ def cache_unfinished_req(self, req: Req, chunked=False): return token_ids = req.fill_ids + all_token_len = len(token_ids) + # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key + actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :all_token_len ] if self.page_size != 1: - page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_len = actual_kv_len // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].to( dtype=torch.int64, copy=True ) else: - page_aligned_len = len(kv_indices) + page_aligned_len = actual_kv_len page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) - page_aligned_token_ids = token_ids[:page_aligned_len] + + # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1 + page_aligned_token_len = ( + page_aligned_len + 1 if self.is_eagle else page_aligned_len + ) + page_aligned_token_ids = token_ids[:page_aligned_token_len] + + old_prefix_len = len(req.prefix_indices) + if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: + # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) + old_prefix_len -= 1 # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( @@ -346,29 +396,40 @@ def cache_unfinished_req(self, req: Req, chunked=False): page_aligned_kv_indices, chunked=chunked, ) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) + self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len]) # The prefix indices could be updated, reuse it new_indices, new_last_node, _, _ = self.match_prefix( RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key) ) self.req_to_token_pool.write( - (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), - new_indices[len(req.prefix_indices) :], + (req.req_pool_idx, slice(old_prefix_len, len(new_indices))), + new_indices[old_prefix_len:], ) + # The last_matched_prefix_len is not always equal to len(req.prefix_indices) + # since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree. + # It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak. + # So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly. + req.last_matched_prefix_len = len(new_indices) + self.dec_lock_ref(req.last_node) self.inc_lock_ref(new_last_node) # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later if self.page_size != 1: + # Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req. req.prefix_indices = torch.cat( [new_indices, kv_indices[len(new_indices) :]] ) else: - req.prefix_indices = new_indices + if self.is_eagle: + # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill + req.prefix_indices = torch.cat( + [new_indices, kv_indices[actual_kv_len:]] + ) + else: + req.prefix_indices = new_indices req.last_node = new_last_node def pretty_print(self): diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 2e9a16896b33..360c852cb876 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -77,7 +77,8 @@ # EAGLE DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf" DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B" -DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B" +DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B" DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = ( "meta-llama/Llama-3.1-8B-Instruct" ) diff --git a/test/srt/test_eagle_infer_a.py b/test/srt/test_eagle_infer_a.py index c19f0c22f082..f956059c0d29 100644 --- a/test/srt/test_eagle_infer_a.py +++ b/test/srt/test_eagle_infer_a.py @@ -9,6 +9,8 @@ from sglang.test.test_utils import ( DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3, + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -35,6 +37,11 @@ class TestEAGLEEngine(CustomTestCase): } NUM_CONFIGS = 2 + THRESHOLDS = { + "batch_avg_accept_len": 1.9, + "accept_len": 3.6, + } + def setUp(self): self.prompt = "Today is a sunny day and I like" self.sampling_params = {"temperature": 0, "max_new_tokens": 8} @@ -63,6 +70,7 @@ def test_correctness(self): self._test_eos_token(engine) self._test_acc_length(engine) finally: + engine.flush_cache() # check engine alive engine.shutdown() print("=" * 100) @@ -92,7 +100,9 @@ def _test_batch_generation(self, engine): "avg_spec_accept_length" ] print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 1.9) + self.assertGreater( + avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"] + ) def _test_eos_token(self, engine): prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]" @@ -131,10 +141,7 @@ def _test_acc_length(self, engine): ) print(f"{acc_length=:.4f}, {speed=}") - if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST: - self.assertGreater(acc_length, 3.6) - else: - self.assertGreater(acc_length, 2.5) + self.assertGreater(acc_length, self.THRESHOLDS["accept_len"]) class TestEAGLEEngineTokenMap(TestEAGLEEngine): @@ -151,12 +158,16 @@ class TestEAGLEEngineTokenMap(TestEAGLEEngine): "dtype": "float16", } NUM_CONFIGS = 1 + THRESHOLDS = { + "batch_avg_accept_len": 1.9, + "accept_len": 2.5, + } class TestEAGLE3Engine(TestEAGLEEngine): BASE_CONFIG = { - "model_path": "meta-llama/Llama-3.1-8B-Instruct", - "speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3, + "speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, "speculative_algorithm": "EAGLE3", "speculative_num_steps": 5, "speculative_eagle_topk": 16, @@ -166,6 +177,72 @@ class TestEAGLE3Engine(TestEAGLEEngine): "dtype": "float16", } NUM_CONFIGS = 1 + THRESHOLDS = { + "batch_avg_accept_len": 1.75, + "accept_len": 3.1, + } + + +class TestEAGLERadixCache(CustomTestCase): + BASE_CONFIG = { + "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3, + "speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, + "speculative_algorithm": "EAGLE3", + "speculative_num_steps": 2, + "speculative_eagle_topk": 1, + "speculative_num_draft_tokens": 3, + "mem_fraction_static": 0.7, + "cuda_graph_max_bs": 5, + "dtype": "float16", + } + + def test_correctness(self): + configs = [ + # Basic config + self.BASE_CONFIG, + # Chunked prefill + {**self.BASE_CONFIG, "chunked_prefill_size": 64}, + # Chunked prefill & Page Size > 1 + {**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4}, + ] + + for i, config in enumerate(configs): + with self.subTest(i=i): + print(f"{config=}") + engine = sgl.Engine(**config, log_level="info", decode_log_interval=10) + try: + self._test_acc_length(engine) + finally: + engine.shutdown() + print("=" * 100) + + def _test_acc_length(self, engine): + warmup_prompt = [ + "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", + ] + sampling_params = {"temperature": 0, "max_new_tokens": 512} + output = engine.generate(warmup_prompt, sampling_params) + test_prompt = [ + "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ] + output = engine.generate(test_prompt, sampling_params) + output = output[0] + + if "spec_verify_ct" in output["meta_info"]: + acc_length = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["spec_verify_ct"] + ) + else: + acc_length = 1.0 + + speed = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["e2e_latency"] + ) + print(f"{acc_length=:.4f}, {speed=}") + + self.assertGreater(acc_length, 2.5) @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") diff --git a/test/srt/test_radix_cache_unit.py b/test/srt/test_radix_cache_unit.py index 8cb75fb0bf84..f8708eaf387e 100644 --- a/test/srt/test_radix_cache_unit.py +++ b/test/srt/test_radix_cache_unit.py @@ -307,6 +307,72 @@ def test_insert_and_match_basic(self): result.device_indices, torch.tensor([10, 20], dtype=torch.int64) ) + def test_insert_and_match_eagle(self): + """Test insert and match operations for EAGLE.""" + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=1, + disable=False, + is_eagle=True, + ) + + key = RadixKey([1, 2, 3, 4]) + value = torch.tensor([10, 20, 30, 40], dtype=torch.int64) + prefix_len = cache.insert(key, value) + + self.assertEqual(prefix_len, 0) # No existing prefix + self.assertEqual( + cache.total_size(), 3 + ) # The last token is ignored in bigram key + self.assertEqual(cache.evictable_size(), 3) + + # Test match_prefix + result = cache.match_prefix(RadixKey([1, 2, 3, 4])) + self.assertEqual(len(result.device_indices), 3) + torch.testing.assert_close( + result.device_indices, torch.tensor([10, 20, 30], dtype=torch.int64) + ) + + # Test partial match + result = cache.match_prefix(RadixKey([1, 2])) + self.assertEqual(len(result.device_indices), 1) + torch.testing.assert_close( + result.device_indices, torch.tensor([10], dtype=torch.int64) + ) + + def test_insert_and_match_eagle_page_size(self): + """Test insert and match operations for EAGLE and page_size > 1.""" + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=2, + disable=False, + is_eagle=True, + ) + + key = RadixKey([1, 2, 3]) + value = torch.tensor([10, 20, 30], dtype=torch.int64) + prefix_len = cache.insert(key, value) + + self.assertEqual(prefix_len, 0) # No existing prefix + self.assertEqual(cache.total_size(), 2) # only one page is inserted + self.assertEqual(cache.evictable_size(), 2) + + # Test match_prefix + result = cache.match_prefix(RadixKey([1, 2, 3, 4])) + self.assertEqual(len(result.device_indices), 2) + torch.testing.assert_close( + result.device_indices, torch.tensor([10, 20], dtype=torch.int64) + ) + + # Test unmatched + result = cache.match_prefix(RadixKey([1, 2])) + self.assertEqual(len(result.device_indices), 0) + torch.testing.assert_close( + result.device_indices, torch.tensor([], dtype=torch.int64) + ) + def test_insert_with_none_value(self): """Test insert with None value (should use token_ids as list).""" cache = RadixCache(