From 4926b37921d604ec49b54acb45417ece3e68a813 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 23 Mar 2026 19:59:31 +0000 Subject: [PATCH 1/7] remove min max match window --- docs/advanced_features/server_arguments.md | 4 +-- .../advanced_features/speculative_decoding.md | 11 ++----- python/sglang/srt/server_args.py | 18 ++--------- .../srt/speculative/cpp_ngram/ngram.cpp | 31 ------------------- .../srt/speculative/cpp_ngram/ngram_corpus.py | 4 --- .../cpp_ngram/ngram_corpus_binding.cpp | 5 --- .../sglang/srt/speculative/cpp_ngram/param.h | 22 ------------- .../sglang/srt/speculative/cpp_ngram/trie.cpp | 29 +++++++++-------- .../sglang/srt/speculative/cpp_ngram/trie.h | 3 +- python/sglang/srt/speculative/ngram_worker.py | 7 +---- python/sglang/test/lora_utils.py | 2 -- python/sglang/test/runners.py | 9 +----- .../spec/utils/test_ngram_corpus.py | 27 +++++++++------- 13 files changed, 38 insertions(+), 134 deletions(-) diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 61cfe91e07c6..46f5e712cac4 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -294,12 +294,10 @@ Please consult the documentation below and [server_args.py](https://github.com/s ## Ngram speculative decoding | Argument | Description | Defaults | Options | | --- | --- | --- | --- | -| `--speculative-ngram-min-match-window-size` | The minimum window size for pattern matching in ngram speculative decoding. | `1` | Type: int | -| `--speculative-ngram-max-match-window-size` | The maximum window size for pattern matching in ngram speculative decoding. | `12` | Type: int | | `--speculative-ngram-min-bfs-breadth` | The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `1` | Type: int | | `--speculative-ngram-max-bfs-breadth` | The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `10` | Type: int | | `--speculative-ngram-match-type` | Ngram tree-building mode. `BFS` selects recency-based expansion and `PROB` selects frequency-based expansion. This setting is forwarded to the ngram cache implementation. | `BFS` | `BFS`, `PROB` | -| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` | Type: int | +| `--speculative-ngram-max-trie-depth` | Maximum suffix length stored and matched by the ngram trie. | `18` | Type: int | | `--speculative-ngram-capacity` | The cache capacity for ngram speculative decoding. | `10000000` | Type: int | ## Multi-layer Eagle speculative decoding diff --git a/docs/advanced_features/speculative_decoding.md b/docs/advanced_features/speculative_decoding.md index c573af0724a8..b8fe2d890cc4 100644 --- a/docs/advanced_features/speculative_decoding.md +++ b/docs/advanced_features/speculative_decoding.md @@ -387,13 +387,11 @@ Enable it with: | Parameter | Description | Default | |---|---|---| -| `--speculative-num-draft-tokens` | Number of draft tokens verified per step. If omitted, defaults to `--speculative-ngram-max-match-window-size`. | `12` (with default ngram settings) | -| `--speculative-ngram-min-match-window-size` | Minimum matching window size. | `1` | -| `--speculative-ngram-max-match-window-size` | Maximum matching window size. | `12` | +| `--speculative-num-draft-tokens` | Number of draft tokens verified per step. If omitted, defaults to `min(--speculative-ngram-max-trie-depth, 12)`. | `12` (with default ngram settings) | | `--speculative-ngram-min-bfs-breadth` | Minimum BFS breadth. | `1` | | `--speculative-ngram-max-bfs-breadth` | Maximum BFS breadth. | `10` | | `--speculative-ngram-match-type` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion. | `"BFS"` | -| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` | +| `--speculative-ngram-max-trie-depth` | Maximum suffix length stored and matched by the ngram trie. | `18` | | `--speculative-ngram-capacity` | Cache capacity (number of entries). | `10,000,000` | Notes: @@ -408,7 +406,6 @@ python3 -m sglang.launch_server \ --model Qwen/Qwen2.5-7B-Instruct \ --speculative-algorithm NGRAM \ --speculative-num-draft-tokens 16 \ - --speculative-ngram-max-match-window-size 12 \ --speculative-ngram-max-bfs-breadth 10 \ --mem-fraction-static 0.7 \ --cuda-graph-max-bs 8 \ @@ -464,12 +461,10 @@ Below is a comprehensive list of all speculative decoding parameters available i | Parameter | Type | Default | Description | |---|---|---|---| -| `--speculative-ngram-min-match-window-size` | `int` | `1` | Minimum ngram matching window | -| `--speculative-ngram-max-match-window-size` | `int` | `12` | Maximum ngram matching window | | `--speculative-ngram-min-bfs-breadth` | `int` | `1` | Minimum BFS breadth | | `--speculative-ngram-max-bfs-breadth` | `int` | `10` | Maximum BFS breadth | | `--speculative-ngram-match-type` | `str` | `"BFS"` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion | -| `--speculative-ngram-max-trie-depth` | `int` | `18` | Max trie depth for ngram speculative decoding | +| `--speculative-ngram-max-trie-depth` | `int` | `18` | Maximum suffix length stored and matched by the ngram trie | | `--speculative-ngram-capacity` | `int` | `10,000,000` | Cache capacity | ### Environment variables diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f94da569a53b..b123fb34dc0b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -501,8 +501,6 @@ class ServerArgs: speculative_draft_model_quantization: Optional[str] = None # Speculative decoding (ngram) - speculative_ngram_min_match_window_size: int = 1 - speculative_ngram_max_match_window_size: int = 12 speculative_ngram_min_bfs_breadth: int = 1 speculative_ngram_max_bfs_breadth: int = 10 speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS" @@ -3049,8 +3047,8 @@ def _handle_speculative_decoding(self): self.enable_mixed_chunk = False self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth if self.speculative_num_draft_tokens is None: - self.speculative_num_draft_tokens = ( - self.speculative_ngram_max_match_window_size + self.speculative_num_draft_tokens = min( + self.speculative_ngram_max_trie_depth, 12 ) logger.warning( "The overlap scheduler and mixed chunked prefill are disabled because of " @@ -4734,18 +4732,6 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # Speculative decoding (ngram) - parser.add_argument( - "--speculative-ngram-min-match-window-size", - type=int, - default=ServerArgs.speculative_ngram_min_match_window_size, - help="The minimum window size for pattern matching in ngram speculative decoding.", - ) - parser.add_argument( - "--speculative-ngram-max-match-window-size", - type=int, - default=ServerArgs.speculative_ngram_max_match_window_size, - help="The maximum window size for pattern matching in ngram speculative decoding.", - ) parser.add_argument( "--speculative-ngram-min-bfs-breadth", type=int, diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp index b1d54b964400..904782774916 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp @@ -13,23 +13,6 @@ Ngram::Ngram(size_t capacity, const Param& param) : param_(param) { throw std::runtime_error( "param_.max_trie_depth must be greater than 1, current value: " + std::to_string(param_.max_trie_depth)); } - if (!(param_.min_match_window_size > 0)) { - throw std::runtime_error( - "min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size)); - } - if (!(param_.min_match_window_size <= param_.max_match_window_size)) { - throw std::runtime_error( - "min_match_window_size must be less than or equal to " - "max_match_window_size, current min_match_window_size: " + - std::to_string(param_.min_match_window_size) + - ", max_match_window_size: " + std::to_string(param_.max_match_window_size)); - } - if (!(param_.max_match_window_size < param_.max_trie_depth)) { - throw std::runtime_error( - "max_match_window_size must be less than max_trie_depth, current " - "max_match_window_size: " + - std::to_string(param_.max_match_window_size) + ", max_trie_depth: " + std::to_string(param_.max_trie_depth)); - } if (!(param_.min_bfs_breadth > 0)) { throw std::runtime_error( "min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth)); @@ -53,20 +36,6 @@ Ngram::Ngram(size_t capacity, const Param& param) : param_(param) { } } } - for (auto config : param_.batch_min_match_window_size) { - if (config != std::numeric_limits::max()) { - if (!(config >= param_.min_match_window_size)) { - throw std::runtime_error( - "batch_min_match_window_size config value " + std::to_string(config) + - " must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size)); - } - if (!(config <= param_.max_match_window_size)) { - throw std::runtime_error( - "batch_min_match_window_size config value " + std::to_string(config) + - " must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size)); - } - } - } trie_ = std::make_unique(capacity, param_); diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py index e44a3da6b2ec..f35e9acf95fe 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py @@ -26,8 +26,6 @@ class NgramCorpus: def __init__( self, max_trie_depth=18, - min_match_window_size=1, - max_match_window_size=10, min_bfs_breadth=1, max_bfs_breadth=8, draft_token_num=8, @@ -36,8 +34,6 @@ def __init__( ): param = ngram_corpus_cpp.Param() param.max_trie_depth = max_trie_depth - param.min_match_window_size = min_match_window_size - param.max_match_window_size = max_match_window_size param.min_bfs_breadth = min_bfs_breadth param.max_bfs_breadth = max_bfs_breadth param.draft_token_num = draft_token_num diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp index 8da395440293..e632dfb3de59 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp @@ -21,17 +21,12 @@ PYBIND11_MODULE(ngram_corpus_cpp, m) { .def_readwrite("enable_router_mode", &Param::enable_router_mode) .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth) .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth) - .def_readwrite("min_match_window_size", &Param::min_match_window_size) - .def_readwrite("max_match_window_size", &Param::max_match_window_size) .def_readwrite("max_trie_depth", &Param::max_trie_depth) .def_readwrite("draft_token_num", &Param::draft_token_num) .def_readwrite("match_type", &Param::match_type) - .def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size) .def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num) .def("get_draft_token_num", &Param::get_draft_token_num, "") - .def("get_min_match_window_size", &Param::get_min_match_window_size, "") .def("parse", &Param::parse, "") - .def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "") .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "") .def("detail", &Param::detail, ""); diff --git a/python/sglang/srt/speculative/cpp_ngram/param.h b/python/sglang/srt/speculative/cpp_ngram/param.h index d31af64ba5b9..725f635db8cd 100644 --- a/python/sglang/srt/speculative/cpp_ngram/param.h +++ b/python/sglang/srt/speculative/cpp_ngram/param.h @@ -17,13 +17,10 @@ struct Param { bool enable_router_mode; size_t min_bfs_breadth; size_t max_bfs_breadth; - size_t min_match_window_size; - size_t max_match_window_size; size_t max_trie_depth; size_t draft_token_num; std::string match_type; - std::vector batch_min_match_window_size; std::vector batch_draft_token_num; size_t get_draft_token_num(size_t batch_size) const { @@ -36,16 +33,6 @@ struct Param { return draft_token_num - 1; } - size_t get_min_match_window_size(size_t batch_size) const { - if (batch_size < batch_min_match_window_size.size()) { - if (batch_min_match_window_size[batch_size] != - std::numeric_limits::max()) { - return batch_min_match_window_size[batch_size]; - } - } - return min_match_window_size; - } - std::vector parse(const std::string& value) { // 0-1|10,2-3|20, std::vector result; @@ -96,10 +83,6 @@ struct Param { return result; } - void resetBatchMinMatchWindowSize(const std::string& value) { - batch_min_match_window_size = parse(value); - } - void resetBatchReturnTokenNum(const std::string& value) { batch_draft_token_num = parse(value); } @@ -108,13 +91,8 @@ struct Param { std::stringstream ss; ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth - << ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size << ", max_trie_depth = " << max_trie_depth << ", draft_token_num = " << draft_token_num << ", match_type = " << match_type; - ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = "; - for (int i = 0; i < batch_min_match_window_size.size(); ++i) { - ss << i << "|" << batch_min_match_window_size[i] << ","; - } ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = "; for (int i = 0; i < batch_draft_token_num.size(); ++i) { ss << i << "|" << batch_draft_token_num[i] << ","; diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/srt/speculative/cpp_ngram/trie.cpp index 8d9eec82b97e..37e391485ffd 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/trie.cpp @@ -19,7 +19,7 @@ Trie::Trie(size_t capacity, const Param& param) : param_(param) { } void Trie::insert(const int32_t* tokens, size_t len) { - for (size_t i = 0; i + param_.min_match_window_size < len; ++i) { + for (size_t i = 0; i + 1 < len; ++i) { auto start = tokens + i; auto end = start + std::min(len - i, param_.max_trie_depth); @@ -100,14 +100,13 @@ void Trie::reset() { root_ = getNode(); } -std::vector> -Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_window) const { +std::vector> Trie::match(const int32_t* context, size_t len) const { std::vector> result; - result.reserve(max_window - min_window); - for (int32_t match_window_size = std::min(len, max_window); match_window_size >= static_cast(min_window); - --match_window_size) { - auto start = context + len - match_window_size; - auto end = start + match_window_size; + const auto max_match_depth = std::min(len, param_.max_trie_depth); + result.reserve(max_match_depth); + for (size_t match_depth = max_match_depth; match_depth > 0; --match_depth) { + auto start = context + len - match_depth; + auto end = start + match_depth; auto cursor = root_; while (start != end) { auto iter = cursor->child.find(*start); @@ -118,8 +117,8 @@ Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_wi ++start; cursor = iter->second; } - if (cursor) { - result.emplace_back(std::make_pair(cursor, match_window_size)); + if (cursor != nullptr && !cursor->child.empty()) { + result.emplace_back(cursor, static_cast(match_depth)); } } return result; @@ -127,10 +126,10 @@ Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_wi Result Trie::buildRecency( const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const { - auto anchors = match(context, len, param.min_match_window_size, param.max_match_window_size); + auto anchors = match(context, len); - double bfs_breadth_scale = double(param.max_bfs_breadth - param.min_bfs_breadth) / - (param.max_match_window_size - param.min_match_window_size + 1); + const auto max_match_depth = std::max(1, static_cast(param.max_trie_depth - 1)); + double bfs_breadth_scale = double(param.max_bfs_breadth - param.min_bfs_breadth) / max_match_depth; std::vector tree(draft_token_num + 1); int root = 0; @@ -138,7 +137,7 @@ Result Trie::buildRecency( for (auto [node, depth] : anchors) { std::queue> queue; - queue.push({root, (param.max_match_window_size - depth) * bfs_breadth_scale + param.min_bfs_breadth, node}); + queue.push({root, (max_match_depth - depth) * bfs_breadth_scale + param.min_bfs_breadth, node}); while (queue.size() && cursor <= static_cast(draft_token_num)) { auto front = queue.front(); queue.pop(); @@ -168,7 +167,7 @@ Result Trie::buildRecency( Result Trie::buildFrequency( const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const { - auto anchors = match(context, len, param.min_match_window_size, param.max_match_window_size); + auto anchors = match(context, len); struct CompareByLastDouble { bool operator()( diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.h b/python/sglang/srt/speculative/cpp_ngram/trie.h index 30db5b29400c..41fd6e54ceb2 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.h +++ b/python/sglang/srt/speculative/cpp_ngram/trie.h @@ -49,8 +49,7 @@ class Trie { void reset(); private: - std::vector> - match(const int32_t* context, size_t len, size_t min_window, size_t max_window) const; + std::vector> match(const int32_t* context, size_t len) const; TrieNode* getNode() { auto node = node_pool_[--free_node_count_]; diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 04a38cefbb83..8c108915c939 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -41,9 +41,6 @@ def __init__( self.page_size = server_args.page_size self.draft_token_num: int = server_args.speculative_num_draft_tokens self.max_trie_depth: int = server_args.speculative_ngram_max_trie_depth - self.max_match_window_size: int = ( - server_args.speculative_ngram_max_match_window_size - ) self.max_batch_size = target_worker.max_running_requests self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda" @@ -51,8 +48,6 @@ def __init__( self._init_preallocated_tensors() self.ngram_corpus = NgramCorpus( - min_match_window_size=server_args.speculative_ngram_min_match_window_size, - max_match_window_size=server_args.speculative_ngram_max_match_window_size, min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth, max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth, match_type=server_args.speculative_ngram_match_type, @@ -131,7 +126,7 @@ def _prepare_draft_tokens( batch_tokens = [] for req in batch.reqs: check_token = self._efficient_concat_last_n( - req.origin_input_ids, req.output_ids, self.max_match_window_size + req.origin_input_ids, req.output_ids, self.max_trie_depth ) batch_tokens.append(check_token) req_drafts, mask = self.ngram_corpus.batch_get(batch_tokens) diff --git a/python/sglang/test/lora_utils.py b/python/sglang/test/lora_utils.py index 634974f2fd28..0a14b1743d0f 100644 --- a/python/sglang/test/lora_utils.py +++ b/python/sglang/test/lora_utils.py @@ -738,8 +738,6 @@ def run_lora_multiple_batch_on_model_cases( else { "speculative_algorithm": "NGRAM", "speculative_num_draft_tokens": 5, - "speculative_ngram_min_match_window_size": 2, - "speculative_ngram_max_match_window_size": 15, } ) srt_runner = SRTRunner( diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 61781fea21de..b78910bd55d9 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -564,8 +564,6 @@ def __init__( speculative_num_steps: Optional[int] = None, speculative_eagle_topk: Optional[int] = None, speculative_num_draft_tokens: Optional[int] = None, - speculative_ngram_min_match_window_size: Optional[int] = None, - speculative_ngram_max_match_window_size: Optional[int] = None, disable_overlap_schedule: bool = False, disable_custom_all_reduce: bool = False, torchao_config: Optional[str] = None, @@ -596,12 +594,7 @@ def __init__( spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens elif speculative_algorithm == "NGRAM": spec_kwargs["speculative_algorithm"] = speculative_algorithm - spec_kwargs["speculative_ngram_min_match_window_size"] = ( - speculative_ngram_min_match_window_size - ) - spec_kwargs["speculative_ngram_max_match_window_size"] = ( - speculative_ngram_max_match_window_size - ) + spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens self.engine = Engine( model_path=model_path, diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py index 6f2427a40966..4e282724c5cd 100644 --- a/test/registered/spec/utils/test_ngram_corpus.py +++ b/test/registered/spec/utils/test_ngram_corpus.py @@ -12,8 +12,6 @@ def _make_corpus(match_type="BFS", **kwargs): defaults = dict( max_trie_depth=12, - min_match_window_size=1, - max_match_window_size=10, min_bfs_breadth=1, max_bfs_breadth=8, draft_token_num=8, @@ -239,9 +237,7 @@ def test_small_capacity_does_not_crash(self): self.assertEqual(len(ids), 8, "Should still produce draft_token_num outputs") def test_eviction_preserves_recent(self): - corpus = _make_corpus( - "BFS", capacity=500, max_trie_depth=6, max_match_window_size=5 - ) + corpus = _make_corpus("BFS", capacity=500, max_trie_depth=6) old_seq = list(range(1000, 1050)) corpus.batch_put([old_seq]) @@ -357,7 +353,6 @@ def test_repeated_insert_promotes_token(self): draft_token_num=2, max_bfs_breadth=1, min_bfs_breadth=1, - max_match_window_size=3, max_trie_depth=5, ) corpus.batch_put([[1, 2, 3, 10, 11]]) @@ -386,7 +381,6 @@ def test_most_recent_insert_selected(self): draft_token_num=2, max_bfs_breadth=1, min_bfs_breadth=1, - max_match_window_size=3, max_trie_depth=5, ) corpus.batch_put([[1, 2, 3, 10, 11]]) @@ -422,7 +416,7 @@ class TestSingleTokenContext(CustomTestCase): """Verify behavior with minimum-length context.""" def test_single_token_query(self): - corpus = _make_corpus("BFS", min_match_window_size=1) + corpus = _make_corpus("BFS") corpus.batch_put([[5, 10, 20, 30]]) corpus.synchronize() @@ -436,7 +430,7 @@ class TestLongContext(CustomTestCase): """Verify behavior when query context exceeds max_trie_depth.""" def test_context_longer_than_max_trie_depth(self): - corpus = _make_corpus("BFS", max_trie_depth=6, max_match_window_size=5) + corpus = _make_corpus("BFS", max_trie_depth=6) seq = list(range(1, 20)) corpus.batch_put([seq]) corpus.synchronize() @@ -447,6 +441,17 @@ def test_context_longer_than_max_trie_depth(self): self.assertEqual(ids_list[0], 15, "First token should be last context token") self.assertIn(16, ids_list, "Should match via suffix despite long context") + def test_matches_longest_stored_suffix(self): + corpus = _make_corpus("BFS", max_trie_depth=6, draft_token_num=4) + corpus.batch_put([[1, 2, 3, 4, 5, 6, 7]]) + corpus.batch_put([[99, 3, 4, 5, 6, 8]]) + corpus.synchronize() + + ids, _ = corpus.batch_get([[2, 3, 4, 5, 6]]) + ids_list = ids.tolist() + self.assertIn(7, ids_list, "Longest stored suffix should contribute a continuation") + self.assertIn(8, ids_list, "Shorter matching suffixes should still contribute continuations") + class TestDraftBudgetSaturation(CustomTestCase): """Verify the draft tree uses exactly draft_token_num slots.""" @@ -538,9 +543,7 @@ class TestSqueezeEvictsOld(CustomTestCase): """Verify that squeeze actually evicts old data, not just preserves recent.""" def test_old_data_evicted(self): - corpus = _make_corpus( - "BFS", capacity=150, max_trie_depth=6, max_match_window_size=5 - ) + corpus = _make_corpus("BFS", capacity=150, max_trie_depth=6) old_seq = list(range(5000, 5030)) corpus.batch_put([old_seq]) From d30835f62e2e56be352790f61bc58f7893f7ef48 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 23 Mar 2026 20:03:30 +0000 Subject: [PATCH 2/7] lint --- test/registered/spec/utils/test_ngram_corpus.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py index 4e282724c5cd..1605c23f1878 100644 --- a/test/registered/spec/utils/test_ngram_corpus.py +++ b/test/registered/spec/utils/test_ngram_corpus.py @@ -449,8 +449,14 @@ def test_matches_longest_stored_suffix(self): ids, _ = corpus.batch_get([[2, 3, 4, 5, 6]]) ids_list = ids.tolist() - self.assertIn(7, ids_list, "Longest stored suffix should contribute a continuation") - self.assertIn(8, ids_list, "Shorter matching suffixes should still contribute continuations") + self.assertIn( + 7, ids_list, "Longest stored suffix should contribute a continuation" + ) + self.assertIn( + 8, + ids_list, + "Shorter matching suffixes should still contribute continuations", + ) class TestDraftBudgetSaturation(CustomTestCase): From 36ed24ce146f11f36e871c848016aa6ddf72edf8 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Mon, 23 Mar 2026 20:09:01 +0000 Subject: [PATCH 3/7] misc --- python/sglang/srt/server_args.py | 6 ++++-- python/sglang/srt/speculative/cpp_ngram/trie.cpp | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b123fb34dc0b..3e898df0e245 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3047,8 +3047,10 @@ def _handle_speculative_decoding(self): self.enable_mixed_chunk = False self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth if self.speculative_num_draft_tokens is None: - self.speculative_num_draft_tokens = min( - self.speculative_ngram_max_trie_depth, 12 + self.speculative_num_draft_tokens = 12 + logger.warning( + "speculative_num_draft_tokens is set to 12 by default for ngram speculative decoding. " + "You can override this by explicitly setting --speculative-num-draft-tokens." ) logger.warning( "The overlap scheduler and mixed chunked prefill are disabled because of " diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/srt/speculative/cpp_ngram/trie.cpp index 37e391485ffd..67058eccb589 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/trie.cpp @@ -19,7 +19,7 @@ Trie::Trie(size_t capacity, const Param& param) : param_(param) { } void Trie::insert(const int32_t* tokens, size_t len) { - for (size_t i = 0; i + 1 < len; ++i) { + for (size_t i = 0; i < len; ++i) { auto start = tokens + i; auto end = start + std::min(len - i, param_.max_trie_depth); From 0d1d60f59addaaf3f037b47417f82b961c5b2dad Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 24 Mar 2026 00:26:31 +0000 Subject: [PATCH 4/7] increment anchor after every decode steps instead of rematching --- .../srt/speculative/cpp_ngram/ngram.cpp | 43 ++++- .../sglang/srt/speculative/cpp_ngram/ngram.h | 12 +- .../srt/speculative/cpp_ngram/ngram_corpus.py | 19 ++- .../cpp_ngram/ngram_corpus_binding.cpp | 1 + .../sglang/srt/speculative/cpp_ngram/trie.cpp | 126 ++++++++++++-- .../sglang/srt/speculative/cpp_ngram/trie.h | 81 ++++++++- python/sglang/srt/speculative/ngram_worker.py | 12 +- .../spec/utils/test_ngram_corpus.py | 156 +++++++++++++++--- 8 files changed, 399 insertions(+), 51 deletions(-) diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp index 904782774916..dd9675bc4c8f 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp @@ -78,10 +78,25 @@ void Ngram::insertWorker() { } } -Result Ngram::batchMatch(const std::vector>& tokens) const { +Result Ngram::batchMatch( + const std::vector& req_ids, + const std::vector>& tokens, + const std::vector& total_lens) { + if (req_ids.size() != tokens.size() || req_ids.size() != total_lens.size()) { + throw std::runtime_error( + "batchMatch expects req_ids, tokens, and total_lens to match in size"); + } + std::unique_lock lock(mutex_); - using BuildFn = Result (Trie::*)(const int32_t*, size_t, int32_t, size_t, const Param&) const; + using BuildFn = Result (Trie::*)( + const int32_t*, + size_t, + int32_t, + size_t, + const Param&, + MatchState&, + size_t) const; BuildFn build_fn; if (param_.match_type == "BFS") { build_fn = &Trie::buildRecency; @@ -92,13 +107,33 @@ Result Ngram::batchMatch(const std::vector>& tokens) const } Result merged; - for (const auto& suffix : tokens) { + for (size_t i = 0; i < req_ids.size(); ++i) { + const auto& suffix = tokens[i]; + if (suffix.empty()) { + throw std::runtime_error("batchMatch received an empty token tail"); + } + + auto& state = match_state_[req_ids[i]]; auto draft_token_num = param_.get_draft_token_num(tokens.size()); - auto res = (trie_.get()->*build_fn)(suffix.data(), suffix.size(), suffix.back(), draft_token_num, param_); + auto res = (trie_.get()->*build_fn)( + suffix.data(), + suffix.size(), + suffix.back(), + draft_token_num, + param_, + state, + total_lens[i]); merged.token.insert(merged.token.end(), res.token.begin(), res.token.end()); merged.mask.insert(merged.mask.end(), res.mask.begin(), res.mask.end()); } return merged; } +void Ngram::eraseMatchState(const std::vector& req_ids) { + std::unique_lock lock(mutex_); + for (const auto& req_id : req_ids) { + match_state_.erase(req_id); + } +} + } // namespace ngram diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.h b/python/sglang/srt/speculative/cpp_ngram/ngram.h index 377b481ae3fe..d9de5651acd1 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.h +++ b/python/sglang/srt/speculative/cpp_ngram/ngram.h @@ -5,7 +5,9 @@ #include #include #include +#include #include +#include #include #include "param.h" @@ -29,6 +31,8 @@ class Ngram { size_t pending_count_ = 0; utils::Queue> insert_queue_; std::thread insert_worker_; + // NOTE(kpham-sgl): maps req_id to per-request MatchState. + std::unordered_map match_state_; public: Ngram(size_t capacity, const Param& param); @@ -38,13 +42,19 @@ class Ngram { void asyncInsert(std::vector>&& tokens); - Result batchMatch(const std::vector>& tokens) const; + Result batchMatch( + const std::vector& req_ids, + const std::vector>& tokens, + const std::vector& total_lens); + + void eraseMatchState(const std::vector& req_ids); void reset() { std::unique_lock lock(mutex_); if (trie_) { trie_->reset(); } + match_state_.clear(); } const Param& param() const { diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py index f35e9acf95fe..3c7888cbf642 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py @@ -52,10 +52,18 @@ def synchronize(self): def reset(self): self._ngram.reset() - def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]: - result = self._ngram.batchMatch(batch_tokens) + def batch_get( + self, + req_ids: List[str], + batch_tokens: List[List[int]], + total_lens: List[int], + ) -> Tuple[np.ndarray, np.ndarray]: + result = self._ngram.batchMatch(req_ids, batch_tokens, total_lens) return np.array(result.token), np.array(result.mask) + def erase_match_state(self, req_ids: List[str]): + self._ngram.eraseMatchState(req_ids) + def leaf_paths_from_mask( self, tokens: List[int], tree_mask: List[List[int]] ) -> List[List[int]]: @@ -131,6 +139,11 @@ def debug_result( corpus.batch_put(token_ids) corpus.synchronize() - decoding_ids, decoding_masks = corpus.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]]) + queries = [[1, 2, 3], [3, 44], [3, 6, 999]] + decoding_ids, decoding_masks = corpus.batch_get( + req_ids=[f"query-{i}" for i in range(len(queries))], + batch_tokens=queries, + total_lens=[len(q) for q in queries], + ) corpus.debug_result(decoding_ids, decoding_masks) diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp index e632dfb3de59..422f2d30e7d9 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp @@ -12,6 +12,7 @@ PYBIND11_MODULE(ngram_corpus_cpp, m) { .def(py::init(), py::arg("capacity"), py::arg("param")) .def("asyncInsert", &Ngram::asyncInsert, "") .def("batchMatch", &Ngram::batchMatch, "") + .def("eraseMatchState", &Ngram::eraseMatchState, "") .def("reset", &Ngram::reset, "") .def("synchronize", &Ngram::synchronize, ""); diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/srt/speculative/cpp_ngram/trie.cpp index 67058eccb589..47f7cf0a3b6a 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/trie.cpp @@ -84,12 +84,14 @@ void Trie::squeeze(size_t count) { last->parent->lru.erase(last->parent_lru_pos); last->parent->sorted_children.erase(last); last->parent->child.erase(last->token); + retireNode(last); node_pool_[free_node_count_++] = last; } } void Trie::reset() { + ++trie_epoch_; global_lru_.clear(); path_.clear(); node_pool_.clear(); @@ -100,11 +102,31 @@ void Trie::reset() { root_ = getNode(); } -std::vector> Trie::match(const int32_t* context, size_t len) const { - std::vector> result; +const TrieNode* Trie::resolve(const MatchState& state, const NodeRef& ref) const { + if (ref.ptr == nullptr || state.trie_epoch != trie_epoch_ || ref.ptr->version != ref.version) { + return nullptr; + } + return ref.ptr; +} + +bool Trie::_validateMatchState(const MatchState& state) const { + if (state.trie_epoch != trie_epoch_) { + return false; + } + for (const auto& ref : state.anchors) { + if (ref.ptr && !resolve(state, ref)) { + return false; + } + } + return true; +} + +bool Trie::_rebuildMatchState(const int32_t* context, size_t len, MatchState& state, size_t total_len) const { const auto max_match_depth = std::min(len, param_.max_trie_depth); - result.reserve(max_match_depth); - for (size_t match_depth = max_match_depth; match_depth > 0; --match_depth) { + state.trie_epoch = trie_epoch_; + state.processed_total_len = total_len; + state.anchors.assign(max_match_depth, {}); + for (size_t match_depth = 1; match_depth <= max_match_depth; ++match_depth) { auto start = context + len - match_depth; auto end = start + match_depth; auto cursor = root_; @@ -117,17 +139,92 @@ std::vector> Trie::match(const int32_t* context, s ++start; cursor = iter->second; } - if (cursor != nullptr && !cursor->child.empty()) { - result.emplace_back(cursor, static_cast(match_depth)); + if (cursor != nullptr) { + state.anchors[match_depth - 1] = capture(cursor); + } + } + return true; +} + +bool Trie::_advanceMatchState(MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const { + if (!_validateMatchState(state)) { + return false; + } + + for (size_t i = 0; i < len; ++i) { + const auto next_depth = std::min(state.anchors.size() + 1, param_.max_trie_depth); + std::vector next(next_depth); + + const auto root_ref = rootRef(); + const auto root = resolve(state, root_ref); + if (root == nullptr) { + return false; + } + if (auto iter = root->child.find(tokens[i]); iter != root->child.end()) { + next[0] = capture(iter->second); + } + + for (size_t depth = 1; depth < next_depth; ++depth) { + const auto& prev_ref = state.anchors[depth - 1]; + if (prev_ref.ptr == nullptr) { + continue; + } + const auto prev_node = resolve(state, prev_ref); + if (prev_node == nullptr) { + return false; + } + if (auto iter = prev_node->child.find(tokens[i]); iter != prev_node->child.end()) { + next[depth] = capture(iter->second); + } + } + + state.anchors.swap(next); + } + + state.processed_total_len = total_len; + return true; +} + +std::vector> Trie::_getExpandableAnchors(const MatchState& state) const { + std::vector> result; + result.reserve(state.anchors.size()); + for (size_t depth = state.anchors.size(); depth > 0; --depth) { + const auto node = resolve(state, state.anchors[depth - 1]); + if (node != nullptr && !node->child.empty()) { + result.emplace_back(node, static_cast(depth)); } } return result; } -Result Trie::buildRecency( - const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const { - auto anchors = match(context, len); +std::vector> Trie::match( + const int32_t* context, + size_t len, + MatchState& state, + size_t total_len) const { + const bool has_forward_progress = total_len >= state.processed_total_len; + const auto appended_len = has_forward_progress ? total_len - state.processed_total_len : 0; + const auto expected_prev_depth = std::min(state.processed_total_len, param_.max_trie_depth); + const bool can_advance = state.trie_epoch == trie_epoch_ && has_forward_progress && appended_len <= len && + state.anchors.size() == expected_prev_depth; + + if (can_advance && _advanceMatchState(state, context + len - appended_len, appended_len, total_len)) { + return _getExpandableAnchors(state); + } + _rebuildMatchState(context, len, state, total_len); + return _getExpandableAnchors(state); +} + +Result Trie::buildRecency( + const int32_t* context, + size_t len, + int32_t last_token, + size_t draft_token_num, + const Param& param, + MatchState& state, + size_t total_len) const { + auto anchors = match(context, len, state, total_len); const auto max_match_depth = std::max(1, static_cast(param.max_trie_depth - 1)); double bfs_breadth_scale = double(param.max_bfs_breadth - param.min_bfs_breadth) / max_match_depth; @@ -166,9 +263,14 @@ Result Trie::buildRecency( } Result Trie::buildFrequency( - const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const { - auto anchors = match(context, len); - + const int32_t* context, + size_t len, + int32_t last_token, + size_t draft_token_num, + const Param& param, + MatchState& state, + size_t total_len) const { + auto anchors = match(context, len, state, total_len); struct CompareByLastDouble { bool operator()( const std::tuple& a, diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.h b/python/sglang/srt/speculative/cpp_ngram/trie.h index 41fd6e54ceb2..565ef1dfad37 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.h +++ b/python/sglang/srt/speculative/cpp_ngram/trie.h @@ -23,6 +23,9 @@ struct TrieNode { TrieNode* parent; std::list lru; int32_t freq = 0; + // Logical generation of this TrieNode. retireNode() bumps it before the node + // goes back to the pool so stale NodeRefs fail validation after reuse. + uint64_t version = 1; struct CompareByFreq { bool operator()(TrieNode* a, TrieNode* b) const { @@ -32,6 +35,23 @@ struct TrieNode { std::multiset sorted_children; }; +// By-value handle to a logical trie location, cached in MatchState. +// We cannot cache TrieNode* alone across decode steps: squeeze() may evict a +// node, and getNode() may later recycle the same address for a different node. +struct NodeRef { + TrieNode* ptr = nullptr; + uint64_t version = 0; +}; + +// Per-request cached anchors. anchors[d - 1] caches the trie match for the +// length-d suffix ending at the current last token; processed_total_len records +// the full request length covered by those cached anchors. +struct MatchState { + uint64_t trie_epoch = 0; + size_t processed_total_len = 0; + std::vector anchors; +}; + class Trie { public: Trie(size_t capacity, const Param& param); @@ -39,22 +59,76 @@ class Trie { void insert(const int32_t* tokens, size_t len); Result buildRecency( - const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const; + const int32_t* context, + size_t len, + int32_t last_token, + size_t draft_token_num, + const Param& param, + MatchState& state, + size_t total_len) const; Result buildFrequency( - const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const; + const int32_t* context, + size_t len, + int32_t last_token, + size_t draft_token_num, + const Param& param, + MatchState& state, + size_t total_len) const; void squeeze(size_t count); void reset(); private: - std::vector> match(const int32_t* context, size_t len) const; + // Stateful suffix matcher. If `state` still represents the previous step for + // this request, infer the newly appended suffix from (`context`, `total_len`) + // and advance anchors incrementally; otherwise rebuild the cached anchors from + // `context`. Returns only the suffix matches that are currently expandable. + std::vector> match( + const int32_t* context, + size_t len, + MatchState& state, + size_t total_len) const; + // Recompute all cached anchors from the current tail. After this, for every + // d in [1, min(len, max_trie_depth)], anchors[d - 1] represents the suffix of + // length d ending at context[len - 1]. + bool _rebuildMatchState(const int32_t* context, size_t len, MatchState& state, size_t total_len) const; + // Advance the cached anchors by consuming the newly appended suffix one + // token at a time, without re-walking all suffixes from root. + bool _advanceMatchState( + MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const; + // Check that every non-empty cached NodeRef in MatchState still resolves to + // the same logical trie node under the current trie_epoch_. + bool _validateMatchState(const MatchState& state) const; + // MatchState keeps all live suffix matches, including leaves. This helper + // filters the cached anchors down to the suffixes that currently have children and + // therefore can seed BFS / PROB draft construction. + std::vector> _getExpandableAnchors(const MatchState& state) const; + // Resolve a cached NodeRef back to a live trie node. nullptr means the + // cached location went stale and the caller should rebuild from context. + const TrieNode* resolve(const MatchState& state, const NodeRef& ref) const; + NodeRef rootRef() const { + return NodeRef{root_, root_->version}; + } + NodeRef capture(TrieNode* node) const { + if (node == nullptr) { + return {}; + } + return NodeRef{node, node->version}; + } + void retireNode(TrieNode* node) { + if (node != nullptr) { + ++node->version; + } + } TrieNode* getNode() { auto node = node_pool_[--free_node_count_]; + auto version = node->version; node->~TrieNode(); new (node) TrieNode(); + node->version = version; return node; } @@ -65,6 +139,7 @@ class Trie { TrieNode* root_; std::vector path_; Param param_; + uint64_t trie_epoch_ = 1; }; } // namespace ngram diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 8c108915c939..d106ceee19c0 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -123,13 +123,17 @@ def _prepare_draft_tokens( bs = batch.batch_size() self.ngram_corpus.synchronize() + req_ids = [] batch_tokens = [] + total_lens = [] for req in batch.reqs: check_token = self._efficient_concat_last_n( req.origin_input_ids, req.output_ids, self.max_trie_depth ) + req_ids.append(req.rid) batch_tokens.append(check_token) - req_drafts, mask = self.ngram_corpus.batch_get(batch_tokens) + total_lens.append(len(req.origin_input_ids) + len(req.output_ids)) + req_drafts, mask = self.ngram_corpus.batch_get(req_ids, batch_tokens, total_lens) total_draft_token_num = len(req_drafts) # Check if speculative decoding is needed; here we always enforce it @@ -261,6 +265,12 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul if batch.return_logprob: add_output_logprobs_for_spec_v1(batch, verify_input, logits_output) self._update_ngram_corpus(batch) + finished_req_ids = [] + for req in batch.reqs: + if req.finished() or req.is_retracted: + finished_req_ids.append(req.rid) + if finished_req_ids: + self.ngram_corpus.erase_match_state(finished_req_ids) batch.forward_mode = ForwardMode.DECODE else: diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py index 1605c23f1878..e07aa6528bb7 100644 --- a/test/registered/spec/utils/test_ngram_corpus.py +++ b/test/registered/spec/utils/test_ngram_corpus.py @@ -1,4 +1,5 @@ import unittest +import uuid import numpy as np @@ -22,6 +23,34 @@ def _make_corpus(match_type="BFS", **kwargs): return NgramCorpus(**defaults) +def _batch_get( + corpus: NgramCorpus, + batch_tokens: list[list[int]], +): + return corpus.batch_get( + req_ids=[uuid.uuid4().hex for _ in range(len(batch_tokens))], + batch_tokens=batch_tokens, + total_lens=[len(tokens) for tokens in batch_tokens], + ) + + +def _batch_get_with_state( + corpus: NgramCorpus, + req_id: str, + current_tokens: list[int], + total_len: int, +): + return corpus.batch_get([req_id], [current_tokens], [total_len]) + + +def _raw_batch_match(corpus: NgramCorpus, batch_tokens: list[list[int]]): + return corpus._ngram.batchMatch( + [uuid.uuid4().hex for _ in range(len(batch_tokens))], + batch_tokens, + [len(tokens) for tokens in batch_tokens], + ) + + SEED_SEQUENCES = [ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 44, 55, 66, 77, 88, 99, 100], @@ -116,7 +145,7 @@ def setUpClass(cls): cls.corpus = _make_corpus("BFS") cls.corpus.batch_put(SEED_SEQUENCES) cls.corpus.synchronize() - ids, masks = cls.corpus.batch_get(QUERY_SEQUENCES) + ids, masks = _batch_get(cls.corpus, QUERY_SEQUENCES) draft = 8 cls.ids = ids.reshape(-1, draft) cls.masks = masks.reshape(-1, draft, draft) @@ -142,7 +171,7 @@ def setUpClass(cls): cls.corpus = _make_corpus("PROB") cls.corpus.batch_put(SEED_SEQUENCES) cls.corpus.synchronize() - ids, masks = cls.corpus.batch_get(QUERY_SEQUENCES) + ids, masks = _batch_get(cls.corpus, QUERY_SEQUENCES) cls.ids = ids.reshape(-1, 8) cls.masks = masks.reshape(-1, 8, 8) @@ -166,7 +195,7 @@ def test_reset_produces_empty_results(self): corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - ids_before, _ = corpus.batch_get([[1, 2, 3]]) + ids_before, _ = _batch_get(corpus, [[1, 2, 3]]) self.assertTrue( any(t != 0 for t in ids_before.tolist()[1:]), "Expected non-trivial draft tokens before reset", @@ -174,7 +203,7 @@ def test_reset_produces_empty_results(self): corpus.reset() - ids_after, _ = corpus.batch_get([[1, 2, 3]]) + ids_after, _ = _batch_get(corpus, [[1, 2, 3]]) self.assertEqual( ids_after.tolist(), [3, 0, 0, 0, 0, 0, 0, 0], @@ -190,7 +219,7 @@ def test_unmatched_query(self): corpus.batch_put([[10, 20, 30, 40, 50]]) corpus.synchronize() - ids, masks = corpus.batch_get([[999, 888, 777]]) + ids, masks = _batch_get(corpus, [[999, 888, 777]]) ids_list = ids.tolist() self.assertEqual(ids_list[0], 777, "First token should be last context token") self.assertTrue( @@ -200,7 +229,7 @@ def test_unmatched_query(self): def test_empty_corpus(self): corpus = _make_corpus("BFS") - ids, masks = corpus.batch_get([[1, 2, 3]]) + ids, masks = _batch_get(corpus, [[1, 2, 3]]) ids_list = ids.tolist() self.assertEqual(ids_list[0], 3) self.assertTrue(all(t == 0 for t in ids_list[1:])) @@ -217,7 +246,7 @@ def test_incremental_inserts(self): corpus.batch_put([[1, 2, 3, 44, 55]]) corpus.synchronize() - ids, _ = corpus.batch_get([[1, 2, 3]]) + ids, _ = _batch_get(corpus, [[1, 2, 3]]) ids_list = ids.tolist() self.assertIn(4, ids_list, "Token 4 from first insert should still match") @@ -233,7 +262,7 @@ def test_small_capacity_does_not_crash(self): corpus.batch_put([long_seq]) corpus.synchronize() - ids, masks = corpus.batch_get([[50, 51, 52]]) + ids, masks = _batch_get(corpus, [[50, 51, 52]]) self.assertEqual(len(ids), 8, "Should still produce draft_token_num outputs") def test_eviction_preserves_recent(self): @@ -247,7 +276,7 @@ def test_eviction_preserves_recent(self): corpus.batch_put([recent_seq]) corpus.synchronize() - ids, _ = corpus.batch_get([[2000, 2001, 2002]]) + ids, _ = _batch_get(corpus, [[2000, 2001, 2002]]) ids_list = ids.tolist() self.assertEqual(ids_list[0], 2002, "Last context token should be first") self.assertIn(2003, ids_list, "Recent sequence should still be matchable") @@ -294,13 +323,13 @@ def test_batch_vs_individual(self): corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - batch_ids, batch_masks = corpus.batch_get(QUERY_SEQUENCES) + batch_ids, batch_masks = _batch_get(corpus, QUERY_SEQUENCES) draft = 8 batch_ids = batch_ids.reshape(-1, draft) batch_masks = batch_masks.reshape(-1, draft, draft) for i, query in enumerate(QUERY_SEQUENCES): - single_ids, single_masks = corpus.batch_get([query]) + single_ids, single_masks = _batch_get(corpus, [query]) single_ids = single_ids.reshape(-1, draft) single_masks = single_masks.reshape(-1, draft, draft) @@ -329,7 +358,7 @@ def test_bfs_mask_invariants(self): corpus = _make_corpus("BFS") corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - _, masks = corpus.batch_get(QUERY_SEQUENCES) + _, masks = _batch_get(corpus, QUERY_SEQUENCES) masks = masks.reshape(-1, 8, 8) for i in range(masks.shape[0]): self._check_mask(masks[i].tolist()) @@ -338,7 +367,7 @@ def test_prob_mask_invariants(self): corpus = _make_corpus("PROB") corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - _, masks = corpus.batch_get(QUERY_SEQUENCES) + _, masks = _batch_get(corpus, QUERY_SEQUENCES) masks = masks.reshape(-1, 8, 8) for i in range(masks.shape[0]): self._check_mask(masks[i].tolist()) @@ -362,7 +391,7 @@ def test_repeated_insert_promotes_token(self): corpus.batch_put([[1, 2, 3, 20, 21]]) corpus.synchronize() - ids, _ = corpus.batch_get([[1, 2, 3]]) + ids, _ = _batch_get(corpus, [[1, 2, 3]]) ids_list = ids.tolist() self.assertEqual( @@ -388,7 +417,7 @@ def test_most_recent_insert_selected(self): corpus.batch_put([[1, 2, 3, 20, 21]]) corpus.synchronize() - ids, _ = corpus.batch_get([[1, 2, 3]]) + ids, _ = _batch_get(corpus, [[1, 2, 3]]) ids_list = ids.tolist() self.assertEqual( ids_list[1], @@ -406,7 +435,7 @@ def test_shared_suffix_both_match(self): corpus.batch_put([[300, 400, 7, 8, 9, 60, 61]]) corpus.synchronize() - ids, _ = corpus.batch_get([[7, 8, 9]]) + ids, _ = _batch_get(corpus, [[7, 8, 9]]) ids_list = ids.tolist() self.assertIn(50, ids_list, "Continuation from first sequence missing") self.assertIn(60, ids_list, "Continuation from second sequence missing") @@ -420,7 +449,7 @@ def test_single_token_query(self): corpus.batch_put([[5, 10, 20, 30]]) corpus.synchronize() - ids, masks = corpus.batch_get([[5]]) + ids, masks = _batch_get(corpus, [[5]]) ids_list = ids.tolist() self.assertEqual(ids_list[0], 5, "First token should be last context token") self.assertIn(10, ids_list, "Should match continuation after single token 5") @@ -436,7 +465,7 @@ def test_context_longer_than_max_trie_depth(self): corpus.synchronize() long_query = list(range(1, 16)) - ids, masks = corpus.batch_get([long_query]) + ids, masks = _batch_get(corpus, [long_query]) ids_list = ids.tolist() self.assertEqual(ids_list[0], 15, "First token should be last context token") self.assertIn(16, ids_list, "Should match via suffix despite long context") @@ -447,7 +476,7 @@ def test_matches_longest_stored_suffix(self): corpus.batch_put([[99, 3, 4, 5, 6, 8]]) corpus.synchronize() - ids, _ = corpus.batch_get([[2, 3, 4, 5, 6]]) + ids, _ = _batch_get(corpus, [[2, 3, 4, 5, 6]]) ids_list = ids.tolist() self.assertIn( 7, ids_list, "Longest stored suffix should contribute a continuation" @@ -468,7 +497,7 @@ def test_full_budget_used(self): corpus.batch_put([seq]) corpus.synchronize() - ids, _ = corpus.batch_get([[1, 2, 3]]) + ids, _ = _batch_get(corpus, [[1, 2, 3]]) ids_list = ids.tolist() self.assertEqual(len(ids_list), 8) non_zero = [t for t in ids_list[1:] if t != 0] @@ -487,7 +516,7 @@ def test_truncate_reduces_output(self): corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - result = corpus._ngram.batchMatch([[1, 2, 3]]) + result = _raw_batch_match(corpus, [[1, 2, 3]]) original_len = len(result.token) self.assertEqual(original_len, 8) @@ -500,12 +529,12 @@ def test_truncate_preserves_mask_structure(self): corpus.batch_put(SEED_SEQUENCES) corpus.synchronize() - result = corpus._ngram.batchMatch([[1, 2, 3]]) + result = _raw_batch_match(corpus, [[1, 2, 3]]) full_ids = list(result.token) full_mask = list(result.mask) n = len(full_ids) - result_copy = corpus._ngram.batchMatch([[1, 2, 3]]) + result_copy = _raw_batch_match(corpus, [[1, 2, 3]]) trunc_n = 4 result_copy.truncate(trunc_n) trunc_mask = list(result_copy.mask) @@ -532,14 +561,14 @@ def test_reset_then_reinsert(self): corpus.batch_put([[10, 20, 30, 40, 50]]) corpus.synchronize() - ids_old, _ = corpus.batch_get([[1, 2, 3]]) + ids_old, _ = _batch_get(corpus, [[1, 2, 3]]) ids_old_list = ids_old.tolist() self.assertTrue( all(t == 0 for t in ids_old_list[1:]), f"Old data should not match after reset+reinsert, got {ids_old_list}", ) - ids_new, _ = corpus.batch_get([[10, 20, 30]]) + ids_new, _ = _batch_get(corpus, [[10, 20, 30]]) ids_new_list = ids_new.tolist() self.assertEqual(ids_new_list[0], 30) self.assertIn(40, ids_new_list, "New data should match after reset+reinsert") @@ -555,7 +584,7 @@ def test_old_data_evicted(self): corpus.batch_put([old_seq]) corpus.synchronize() - ids_before, _ = corpus.batch_get([[5000, 5001, 5002]]) + ids_before, _ = _batch_get(corpus, [[5000, 5001, 5002]]) self.assertIn( 5003, ids_before.tolist(), @@ -567,7 +596,7 @@ def test_old_data_evicted(self): corpus.batch_put([new_seq]) corpus.synchronize() - ids_after, _ = corpus.batch_get([[5000, 5001, 5002]]) + ids_after, _ = _batch_get(corpus, [[5000, 5001, 5002]]) ids_after_list = ids_after.tolist() self.assertNotIn( 5003, @@ -576,5 +605,78 @@ def test_old_data_evicted(self): ) +class TestNgramCorpusIncremental(CustomTestCase): + """Verify the incremental matching path matches the stateless path.""" + + def _assert_incremental_matches_stateless(self, match_type: str): + corpus = _make_corpus(match_type, max_trie_depth=4, draft_token_num=4) + corpus.batch_put([[1, 2, 3, 4, 5, 6], [9, 3, 4, 7, 8]]) + corpus.synchronize() + + req_id = f"req-{match_type.lower()}" + + steps = [ + [1, 2, 3], + [1, 2, 3, 4], + [1, 2, 3, 4, 5, 6], + ] + for full_sequence in steps: + current_tail = full_sequence[-4:] + inc_ids, inc_masks = _batch_get_with_state( + corpus, + req_id, + current_tail, + len(full_sequence), + ) + full_ids, full_masks = _batch_get(corpus, [current_tail]) + np.testing.assert_array_equal(inc_ids, full_ids) + np.testing.assert_array_equal(inc_masks, full_masks) + + def test_incremental_matches_stateless_bfs(self): + self._assert_incremental_matches_stateless("BFS") + + def test_incremental_matches_stateless_prob(self): + self._assert_incremental_matches_stateless("PROB") + + def test_leaf_anchor_becomes_expandable(self): + corpus = _make_corpus("BFS", max_trie_depth=4, draft_token_num=4) + corpus.batch_put([[1, 2, 3]]) + corpus.synchronize() + + req_id = "leaf-anchor" + ids_before, _ = _batch_get_with_state(corpus, req_id, [2, 3], 2) + self.assertTrue( + all(t == 0 for t in ids_before.tolist()[1:]), + f"Expected only the last token before extension, got {ids_before.tolist()}", + ) + + corpus.batch_put([[9, 2, 3, 4]]) + corpus.synchronize() + + inc_ids, inc_masks = _batch_get_with_state(corpus, req_id, [2, 3], 2) + full_ids, full_masks = _batch_get(corpus, [[2, 3]]) + np.testing.assert_array_equal(inc_ids, full_ids) + np.testing.assert_array_equal(inc_masks, full_masks) + self.assertIn(4, inc_ids.tolist(), f"Expected token 4 after extension, got {inc_ids.tolist()}") + + def test_stale_state_rebuilds_after_eviction(self): + corpus = _make_corpus("BFS", capacity=150, max_trie_depth=6, draft_token_num=4) + corpus.batch_put([list(range(5000, 5030))]) + corpus.synchronize() + + req_id = "evicted" + _batch_get_with_state(corpus, req_id, [5000, 5001, 5002], 3) + + for i in range(5): + new_seq = list(range(6000 + i * 30, 6000 + i * 30 + 30)) + corpus.batch_put([new_seq]) + corpus.synchronize() + + inc_ids, inc_masks = _batch_get_with_state(corpus, req_id, [5000, 5001, 5002], 3) + full_ids, full_masks = _batch_get(corpus, [[5000, 5001, 5002]]) + np.testing.assert_array_equal(inc_ids, full_ids) + np.testing.assert_array_equal(inc_masks, full_masks) + + if __name__ == "__main__": unittest.main(verbosity=3) From 19510022168a67cc7daea76a2074da09bb323508 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 24 Mar 2026 00:28:14 +0000 Subject: [PATCH 5/7] lint --- .../srt/speculative/cpp_ngram/ngram.cpp | 20 +++---------------- .../sglang/srt/speculative/cpp_ngram/trie.cpp | 7 ++----- .../sglang/srt/speculative/cpp_ngram/trie.h | 10 +++------- python/sglang/srt/speculative/ngram_worker.py | 4 +++- .../spec/utils/test_ngram_corpus.py | 10 ++++++++-- 5 files changed, 19 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp index dd9675bc4c8f..b88f6801849d 100644 --- a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp @@ -83,20 +83,12 @@ Result Ngram::batchMatch( const std::vector>& tokens, const std::vector& total_lens) { if (req_ids.size() != tokens.size() || req_ids.size() != total_lens.size()) { - throw std::runtime_error( - "batchMatch expects req_ids, tokens, and total_lens to match in size"); + throw std::runtime_error("batchMatch expects req_ids, tokens, and total_lens to match in size"); } std::unique_lock lock(mutex_); - using BuildFn = Result (Trie::*)( - const int32_t*, - size_t, - int32_t, - size_t, - const Param&, - MatchState&, - size_t) const; + using BuildFn = Result (Trie::*)(const int32_t*, size_t, int32_t, size_t, const Param&, MatchState&, size_t) const; BuildFn build_fn; if (param_.match_type == "BFS") { build_fn = &Trie::buildRecency; @@ -116,13 +108,7 @@ Result Ngram::batchMatch( auto& state = match_state_[req_ids[i]]; auto draft_token_num = param_.get_draft_token_num(tokens.size()); auto res = (trie_.get()->*build_fn)( - suffix.data(), - suffix.size(), - suffix.back(), - draft_token_num, - param_, - state, - total_lens[i]); + suffix.data(), suffix.size(), suffix.back(), draft_token_num, param_, state, total_lens[i]); merged.token.insert(merged.token.end(), res.token.begin(), res.token.end()); merged.mask.insert(merged.mask.end(), res.mask.begin(), res.mask.end()); } diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/srt/speculative/cpp_ngram/trie.cpp index 47f7cf0a3b6a..7e9292c571ee 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/trie.cpp @@ -197,11 +197,8 @@ std::vector> Trie::_getExpandableAnchors(con return result; } -std::vector> Trie::match( - const int32_t* context, - size_t len, - MatchState& state, - size_t total_len) const { +std::vector> +Trie::match(const int32_t* context, size_t len, MatchState& state, size_t total_len) const { const bool has_forward_progress = total_len >= state.processed_total_len; const auto appended_len = has_forward_progress ? total_len - state.processed_total_len : 0; const auto expected_prev_depth = std::min(state.processed_total_len, param_.max_trie_depth); diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.h b/python/sglang/srt/speculative/cpp_ngram/trie.h index 565ef1dfad37..76f156c14d07 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.h +++ b/python/sglang/srt/speculative/cpp_ngram/trie.h @@ -85,19 +85,15 @@ class Trie { // this request, infer the newly appended suffix from (`context`, `total_len`) // and advance anchors incrementally; otherwise rebuild the cached anchors from // `context`. Returns only the suffix matches that are currently expandable. - std::vector> match( - const int32_t* context, - size_t len, - MatchState& state, - size_t total_len) const; + std::vector> + match(const int32_t* context, size_t len, MatchState& state, size_t total_len) const; // Recompute all cached anchors from the current tail. After this, for every // d in [1, min(len, max_trie_depth)], anchors[d - 1] represents the suffix of // length d ending at context[len - 1]. bool _rebuildMatchState(const int32_t* context, size_t len, MatchState& state, size_t total_len) const; // Advance the cached anchors by consuming the newly appended suffix one // token at a time, without re-walking all suffixes from root. - bool _advanceMatchState( - MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const; + bool _advanceMatchState(MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const; // Check that every non-empty cached NodeRef in MatchState still resolves to // the same logical trie node under the current trie_epoch_. bool _validateMatchState(const MatchState& state) const; diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index d106ceee19c0..3b184d9466fb 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -133,7 +133,9 @@ def _prepare_draft_tokens( req_ids.append(req.rid) batch_tokens.append(check_token) total_lens.append(len(req.origin_input_ids) + len(req.output_ids)) - req_drafts, mask = self.ngram_corpus.batch_get(req_ids, batch_tokens, total_lens) + req_drafts, mask = self.ngram_corpus.batch_get( + req_ids, batch_tokens, total_lens + ) total_draft_token_num = len(req_drafts) # Check if speculative decoding is needed; here we always enforce it diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py index e07aa6528bb7..e7c5ce406de9 100644 --- a/test/registered/spec/utils/test_ngram_corpus.py +++ b/test/registered/spec/utils/test_ngram_corpus.py @@ -657,7 +657,11 @@ def test_leaf_anchor_becomes_expandable(self): full_ids, full_masks = _batch_get(corpus, [[2, 3]]) np.testing.assert_array_equal(inc_ids, full_ids) np.testing.assert_array_equal(inc_masks, full_masks) - self.assertIn(4, inc_ids.tolist(), f"Expected token 4 after extension, got {inc_ids.tolist()}") + self.assertIn( + 4, + inc_ids.tolist(), + f"Expected token 4 after extension, got {inc_ids.tolist()}", + ) def test_stale_state_rebuilds_after_eviction(self): corpus = _make_corpus("BFS", capacity=150, max_trie_depth=6, draft_token_num=4) @@ -672,7 +676,9 @@ def test_stale_state_rebuilds_after_eviction(self): corpus.batch_put([new_seq]) corpus.synchronize() - inc_ids, inc_masks = _batch_get_with_state(corpus, req_id, [5000, 5001, 5002], 3) + inc_ids, inc_masks = _batch_get_with_state( + corpus, req_id, [5000, 5001, 5002], 3 + ) full_ids, full_masks = _batch_get(corpus, [[5000, 5001, 5002]]) np.testing.assert_array_equal(inc_ids, full_ids) np.testing.assert_array_equal(inc_masks, full_masks) From 0e2d34906c7399b579f94ff2427229d6b47ab8c6 Mon Sep 17 00:00:00 2001 From: kpham-sgl Date: Tue, 24 Mar 2026 00:38:08 +0000 Subject: [PATCH 6/7] nit --- .../sglang/srt/speculative/cpp_ngram/trie.cpp | 19 +++++++++---------- .../sglang/srt/speculative/cpp_ngram/trie.h | 8 ++++---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.cpp b/python/sglang/srt/speculative/cpp_ngram/trie.cpp index 7e9292c571ee..a429a5576933 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/trie.cpp @@ -109,7 +109,7 @@ const TrieNode* Trie::resolve(const MatchState& state, const NodeRef& ref) const return ref.ptr; } -bool Trie::_validateMatchState(const MatchState& state) const { +bool Trie::validateMatchState_(const MatchState& state) const { if (state.trie_epoch != trie_epoch_) { return false; } @@ -121,7 +121,7 @@ bool Trie::_validateMatchState(const MatchState& state) const { return true; } -bool Trie::_rebuildMatchState(const int32_t* context, size_t len, MatchState& state, size_t total_len) const { +void Trie::rebuildMatchState_(const int32_t* context, size_t len, MatchState& state, size_t total_len) const { const auto max_match_depth = std::min(len, param_.max_trie_depth); state.trie_epoch = trie_epoch_; state.processed_total_len = total_len; @@ -143,11 +143,10 @@ bool Trie::_rebuildMatchState(const int32_t* context, size_t len, MatchState& st state.anchors[match_depth - 1] = capture(cursor); } } - return true; } -bool Trie::_advanceMatchState(MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const { - if (!_validateMatchState(state)) { +bool Trie::advanceMatchState_(MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const { + if (!validateMatchState_(state)) { return false; } @@ -185,7 +184,7 @@ bool Trie::_advanceMatchState(MatchState& state, const int32_t* tokens, size_t l return true; } -std::vector> Trie::_getExpandableAnchors(const MatchState& state) const { +std::vector> Trie::getExpandableAnchors_(const MatchState& state) const { std::vector> result; result.reserve(state.anchors.size()); for (size_t depth = state.anchors.size(); depth > 0; --depth) { @@ -205,12 +204,12 @@ Trie::match(const int32_t* context, size_t len, MatchState& state, size_t total_ const bool can_advance = state.trie_epoch == trie_epoch_ && has_forward_progress && appended_len <= len && state.anchors.size() == expected_prev_depth; - if (can_advance && _advanceMatchState(state, context + len - appended_len, appended_len, total_len)) { - return _getExpandableAnchors(state); + if (can_advance && advanceMatchState_(state, context + len - appended_len, appended_len, total_len)) { + return getExpandableAnchors_(state); } - _rebuildMatchState(context, len, state, total_len); - return _getExpandableAnchors(state); + rebuildMatchState_(context, len, state, total_len); + return getExpandableAnchors_(state); } Result Trie::buildRecency( diff --git a/python/sglang/srt/speculative/cpp_ngram/trie.h b/python/sglang/srt/speculative/cpp_ngram/trie.h index 76f156c14d07..954e508c4927 100644 --- a/python/sglang/srt/speculative/cpp_ngram/trie.h +++ b/python/sglang/srt/speculative/cpp_ngram/trie.h @@ -90,17 +90,17 @@ class Trie { // Recompute all cached anchors from the current tail. After this, for every // d in [1, min(len, max_trie_depth)], anchors[d - 1] represents the suffix of // length d ending at context[len - 1]. - bool _rebuildMatchState(const int32_t* context, size_t len, MatchState& state, size_t total_len) const; + void rebuildMatchState_(const int32_t* context, size_t len, MatchState& state, size_t total_len) const; // Advance the cached anchors by consuming the newly appended suffix one // token at a time, without re-walking all suffixes from root. - bool _advanceMatchState(MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const; + bool advanceMatchState_(MatchState& state, const int32_t* tokens, size_t len, size_t total_len) const; // Check that every non-empty cached NodeRef in MatchState still resolves to // the same logical trie node under the current trie_epoch_. - bool _validateMatchState(const MatchState& state) const; + bool validateMatchState_(const MatchState& state) const; // MatchState keeps all live suffix matches, including leaves. This helper // filters the cached anchors down to the suffixes that currently have children and // therefore can seed BFS / PROB draft construction. - std::vector> _getExpandableAnchors(const MatchState& state) const; + std::vector> getExpandableAnchors_(const MatchState& state) const; // Resolve a cached NodeRef back to a live trie node. nullptr means the // cached location went stale and the caller should rebuild from context. const TrieNode* resolve(const MatchState& state, const NodeRef& ref) const; From 67140c5da78c4558f41989a0e8926e15a20ae014 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sun, 5 Apr 2026 21:50:18 -0700 Subject: [PATCH 7/7] tiny fix lint --- .../sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp | 4 +--- test/registered/spec/utils/test_ngram_corpus.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp index 486eb146f97a..bbd815c51c85 100644 --- a/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/ngram_corpus_ffi.cpp @@ -107,9 +107,7 @@ struct NgramCorpusObj : public tvm::ffi::Object { private: void write_result_( - const ngram::Result& result, - const tvm::ffi::TensorView& out_tokens, - const tvm::ffi::TensorView& out_mask) { + const ngram::Result& result, const tvm::ffi::TensorView& out_tokens, const tvm::ffi::TensorView& out_mask) { auto* out_tok = static_cast(out_tokens.data_ptr()); auto* out_msk = static_cast(out_mask.data_ptr()); if (result.token.size() > static_cast(out_tokens.size(0))) { diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/spec/utils/test_ngram_corpus.py index 1ab9c886a5be..f0169306df61 100644 --- a/test/registered/spec/utils/test_ngram_corpus.py +++ b/test/registered/spec/utils/test_ngram_corpus.py @@ -526,7 +526,6 @@ def test_truncate_preserves_mask_structure(self): n = 8 full_mask = masks.reshape(n, n) - trunc_n = 4 trunc_mask = full_mask[:trunc_n, :trunc_n]