diff --git a/.github/workflows/rerun-test.yml b/.github/workflows/rerun-test.yml index 84a4537865d4..431b69474c1c 100644 --- a/.github/workflows/rerun-test.yml +++ b/.github/workflows/rerun-test.yml @@ -84,11 +84,31 @@ jobs: source /etc/profile.d/sglang-ci.sh fi cd test/ - echo "${{ inputs.test_command }}" | while IFS= read -r cmd; do + # Collect non-empty commands into an array for counting. + cmds=() + while IFS= read -r cmd; do [ -z "$cmd" ] && continue - echo ">>> Running: python3 $cmd" + cmds+=("$cmd") + done <<< "${{ inputs.test_command }}" + total=${#cmds[@]} + suite_start=$SECONDS + for idx in "${!cmds[@]}"; do + i=$((idx + 1)) + cmd="${cmds[$idx]}" + echo "" + echo "." + echo "Begin ($i/$total): python3 $cmd" + echo "." + file_start=$SECONDS python3 $cmd -f || exit 1 + elapsed=$(( SECONDS - file_start )) + echo "." + echo "End ($i/$total): elapsed=${elapsed}s" + echo "." + echo "" done + total_elapsed=$(( SECONDS - suite_start )) + echo "All $total test(s) passed in ${total_elapsed}s" - uses: ./.github/actions/upload-cuda-coredumps if: always() @@ -129,8 +149,28 @@ jobs: timeout-minutes: 60 run: | cd test/ - echo "${{ inputs.test_command }}" | while IFS= read -r cmd; do + # Collect non-empty commands into an array for counting. + cmds=() + while IFS= read -r cmd; do [ -z "$cmd" ] && continue - echo ">>> Running: python3 $cmd" + cmds+=("$cmd") + done <<< "${{ inputs.test_command }}" + total=${#cmds[@]} + suite_start=$SECONDS + for idx in "${!cmds[@]}"; do + i=$((idx + 1)) + cmd="${cmds[$idx]}" + echo "" + echo "." + echo "Begin ($i/$total): python3 $cmd" + echo "." + file_start=$SECONDS python3 $cmd -f || exit 1 + elapsed=$(( SECONDS - file_start )) + echo "." + echo "End ($i/$total): elapsed=${elapsed}s" + echo "." 
+ echo "" done + total_elapsed=$(( SECONDS - suite_start )) + echo "All $total test(s) passed in ${total_elapsed}s" diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp index a429a5576933..1cbb5eef55a0 100644 --- a/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp @@ -91,6 +91,8 @@ void Trie::squeeze(size_t count) { } void Trie::reset() { + // Epoch bump invalidates all cached MatchState objects, so we do not need to + // retireNode() on every node individually. ++trie_epoch_; global_lru_.clear(); path_.clear(); @@ -150,16 +152,17 @@ bool Trie::advanceMatchState_(MatchState& state, const int32_t* tokens, size_t l return false; } + // Reuse a single buffer across iterations to avoid per-token heap allocation. + std::vector next; + next.reserve(param_.max_trie_depth); + for (size_t i = 0; i < len; ++i) { const auto next_depth = std::min(state.anchors.size() + 1, param_.max_trie_depth); - std::vector next(next_depth); + next.assign(next_depth, {}); - const auto root_ref = rootRef(); - const auto root = resolve(state, root_ref); - if (root == nullptr) { - return false; - } - if (auto iter = root->child.find(tokens[i]); iter != root->child.end()) { + // Root is never evicted, so we access it directly; the epoch was already + // validated above. + if (auto iter = root_->child.find(tokens[i]); iter != root_->child.end()) { next[0] = capture(iter->second); } diff --git a/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h index 909ffae8fa27..76707eea1e89 100644 --- a/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h +++ b/python/sglang/jit_kernel/csrc/ngram_corpus/trie.h @@ -24,6 +24,8 @@ struct TrieNode { int32_t freq = 0; // Logical generation of this TrieNode. retireNode() bumps it before the node // goes back to the pool so stale NodeRefs fail validation after reuse. 
+ // Starts at 1 so that a default-constructed NodeRef (version=0) never + // accidentally resolves to a live node. uint64_t version = 1; struct CompareByFreq { diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index bf332ebeb091..cbc9586601ca 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -269,6 +269,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul if batch.return_logprob: add_output_logprobs_for_spec_v1(batch, verify_input, logits_output) self._update_ngram_corpus(batch) + # Clean up per-request match state for finished/retracted requests. + # State entries are created in _prepare_draft_tokens and cleaned here. + # If a request is removed without passing through verify, the entry + # persists until reset(); this is acceptable because MatchState is small. finished_req_ids = [] for req in batch.reqs: if req.finished() or req.is_retracted: diff --git a/test/registered/spec/utils/test_ngram_corpus.py b/test/registered/unit/spec/test_ngram_corpus.py similarity index 89% rename from test/registered/spec/utils/test_ngram_corpus.py rename to test/registered/unit/spec/test_ngram_corpus.py index f0169306df61..94a15718391f 100644 --- a/test/registered/spec/utils/test_ngram_corpus.py +++ b/test/registered/unit/spec/test_ngram_corpus.py @@ -4,10 +4,10 @@ import numpy as np from sglang.srt.speculative.cpp_ngram.ngram_corpus import NgramCorpus -from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.ci.ci_register import register_cpu_ci from sglang.test.test_utils import CustomTestCase -register_cuda_ci(est_time=8, suite="stage-b-test-1-gpu-small") +register_cpu_ci(est_time=10, suite="stage-a-test-cpu") def _make_corpus(match_type="BFS", **kwargs): @@ -674,5 +674,69 @@ def test_stale_state_rebuilds_after_eviction(self): np.testing.assert_array_equal(inc_masks, full_masks) +class 
TestNgramCorpusMatchBenchmark(CustomTestCase): + """Benchmark incremental advance vs full rebuild in match().""" + + def test_incremental_faster_than_rebuild(self): + """Incremental advance (O(D) per token) should be faster than rebuild (O(D^2)).""" + import time + + max_trie_depth = 18 + draft_token_num = 8 + corpus = _make_corpus( + "BFS", + max_trie_depth=max_trie_depth, + draft_token_num=draft_token_num, + capacity=500000, + ) + + # Seed the trie with diverse sequences so suffix matching is non-trivial. + seed_data = [list(range(i, i + 50)) for i in range(0, 5000, 50)] + corpus.batch_put(seed_data) + corpus.synchronize() + + num_steps = 500 + base_seq = list(range(1, max_trie_depth + 1)) + + # --- Incremental path: same req_id, total_len grows by 1 each step --- + req_id = "bench-incremental" + # Warm up the state with the initial context. + _batch_get_with_state(corpus, req_id, base_seq, len(base_seq)) + + start = time.perf_counter() + for step in range(num_steps): + total_len = len(base_seq) + step + 1 + new_token = (step + max_trie_depth + 1) % 5000 + tail = (base_seq + [new_token])[-max_trie_depth:] + base_seq = tail + _batch_get_with_state(corpus, req_id, tail, total_len) + incremental_us = (time.perf_counter() - start) / num_steps * 1e6 + + # --- Rebuild path: unique req_id each call forces fresh state --- + base_seq = list(range(1, max_trie_depth + 1)) + start = time.perf_counter() + for step in range(num_steps): + new_token = (step + max_trie_depth + 1) % 5000 + tail = (base_seq + [new_token])[-max_trie_depth:] + base_seq = tail + _batch_get(corpus, [tail]) + rebuild_us = (time.perf_counter() - start) / num_steps * 1e6 + + print( + f"\n Incremental: {incremental_us:.1f} us/step" + f"\n Rebuild: {rebuild_us:.1f} us/step" + f"\n Speedup: {rebuild_us / incremental_us:.2f}x" + ) + + # The incremental path should be at least as fast; allow a small margin + # for noise. With D=18 the theoretical speedup is ~18x (D^2/D). 
+ self.assertLess( + incremental_us, + rebuild_us * 1.1, + f"Incremental ({incremental_us:.1f} us) should not be slower than " + f"rebuild ({rebuild_us:.1f} us)", + ) + + if __name__ == "__main__": unittest.main(verbosity=3)