Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions .github/workflows/rerun-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,31 @@ jobs:
source /etc/profile.d/sglang-ci.sh
fi
cd test/
echo "${{ inputs.test_command }}" | while IFS= read -r cmd; do
# Collect non-empty commands into an array for counting.
cmds=()
while IFS= read -r cmd; do
[ -z "$cmd" ] && continue
echo ">>> Running: python3 $cmd"
cmds+=("$cmd")
done <<< "${{ inputs.test_command }}"
total=${#cmds[@]}
suite_start=$SECONDS
for idx in "${!cmds[@]}"; do
i=$((idx + 1))
cmd="${cmds[$idx]}"
echo ""
echo "."
echo "Begin ($i/$total): python3 $cmd"
echo "."
file_start=$SECONDS
python3 $cmd -f || exit 1
elapsed=$(( SECONDS - file_start ))
echo "."
echo "End ($i/$total): elapsed=${elapsed}s"
echo "."
echo ""
done
total_elapsed=$(( SECONDS - suite_start ))
echo "All $total test(s) passed in ${total_elapsed}s"

- uses: ./.github/actions/upload-cuda-coredumps
if: always()
Expand Down Expand Up @@ -129,8 +149,28 @@ jobs:
timeout-minutes: 60
run: |
cd test/
echo "${{ inputs.test_command }}" | while IFS= read -r cmd; do
# Collect non-empty commands into an array for counting.
cmds=()
while IFS= read -r cmd; do
[ -z "$cmd" ] && continue
echo ">>> Running: python3 $cmd"
cmds+=("$cmd")
done <<< "${{ inputs.test_command }}"
total=${#cmds[@]}
suite_start=$SECONDS
for idx in "${!cmds[@]}"; do
i=$((idx + 1))
cmd="${cmds[$idx]}"
echo ""
echo "."
echo "Begin ($i/$total): python3 $cmd"
echo "."
file_start=$SECONDS
python3 $cmd -f || exit 1
elapsed=$(( SECONDS - file_start ))
echo "."
echo "End ($i/$total): elapsed=${elapsed}s"
echo "."
echo ""
done
total_elapsed=$(( SECONDS - suite_start ))
echo "All $total test(s) passed in ${total_elapsed}s"
17 changes: 10 additions & 7 deletions python/sglang/jit_kernel/csrc/ngram_corpus/trie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ void Trie::squeeze(size_t count) {
}

void Trie::reset() {
// Epoch bump invalidates all cached MatchState objects, so we do not need to
// retireNode() on every node individually.
++trie_epoch_;
global_lru_.clear();
path_.clear();
Expand Down Expand Up @@ -150,16 +152,17 @@ bool Trie::advanceMatchState_(MatchState& state, const int32_t* tokens, size_t l
return false;
}

// Reuse a single buffer across iterations to avoid per-token heap allocation.
std::vector<NodeRef> next;
next.reserve(param_.max_trie_depth);

for (size_t i = 0; i < len; ++i) {
const auto next_depth = std::min(state.anchors.size() + 1, param_.max_trie_depth);
std::vector<NodeRef> next(next_depth);
next.assign(next_depth, {});

const auto root_ref = rootRef();
const auto root = resolve(state, root_ref);
if (root == nullptr) {
return false;
}
if (auto iter = root->child.find(tokens[i]); iter != root->child.end()) {
// Root is never evicted, so we access it directly; the epoch was already
// validated above.
if (auto iter = root_->child.find(tokens[i]); iter != root_->child.end()) {
next[0] = capture(iter->second);
}

Expand Down
2 changes: 2 additions & 0 deletions python/sglang/jit_kernel/csrc/ngram_corpus/trie.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ struct TrieNode {
int32_t freq = 0;
// Logical generation of this TrieNode. retireNode() bumps it before the node
// goes back to the pool so stale NodeRefs fail validation after reuse.
// Starts at 1 so that a default-constructed NodeRef (version=0) never
// accidentally resolves to a live node.
uint64_t version = 1;

struct CompareByFreq {
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/speculative/ngram_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
if batch.return_logprob:
add_output_logprobs_for_spec_v1(batch, verify_input, logits_output)
self._update_ngram_corpus(batch)
# Clean up per-request match state for finished/retracted requests.
# State entries are created in _prepare_draft_tokens and cleaned here.
# If a request is removed without passing through verify, the entry
# persists until reset(); this is acceptable because MatchState is small.
finished_req_ids = []
for req in batch.reqs:
if req.finished() or req.is_retracted:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import numpy as np

from sglang.srt.speculative.cpp_ngram.ngram_corpus import NgramCorpus
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.ci.ci_register import register_cpu_ci
from sglang.test.test_utils import CustomTestCase

register_cuda_ci(est_time=8, suite="stage-b-test-1-gpu-small")
register_cpu_ci(est_time=10, suite="stage-a-test-cpu")


def _make_corpus(match_type="BFS", **kwargs):
Expand Down Expand Up @@ -674,5 +674,69 @@ def test_stale_state_rebuilds_after_eviction(self):
np.testing.assert_array_equal(inc_masks, full_masks)


class TestNgramCorpusMatchBenchmark(CustomTestCase):
    """Benchmark incremental advance vs full rebuild in match()."""

    def test_incremental_faster_than_rebuild(self):
        """Incremental advance (O(D) per token) should be faster than rebuild (O(D^2))."""
        import time

        # D: maximum suffix length tracked by the trie; drives the asymptotic gap.
        max_trie_depth = 18
        draft_token_num = 8
        corpus = _make_corpus(
            "BFS",
            max_trie_depth=max_trie_depth,
            draft_token_num=draft_token_num,
            capacity=500000,
        )

        # Seed the trie with diverse sequences so suffix matching is non-trivial.
        seed_data = [list(range(i, i + 50)) for i in range(0, 5000, 50)]
        corpus.batch_put(seed_data)
        corpus.synchronize()

        num_steps = 500
        base_seq = list(range(1, max_trie_depth + 1))

        # --- Incremental path: same req_id, total_len grows by 1 each step ---
        # Reusing one req_id lets the corpus keep per-request MatchState alive,
        # so each call only advances by the newly appended token.
        req_id = "bench-incremental"
        # Warm up the state with the initial context.
        _batch_get_with_state(corpus, req_id, base_seq, len(base_seq))

        start = time.perf_counter()
        for step in range(num_steps):
            # total_len grows monotonically (base_seq stays at length D below),
            # signaling "one new token" to the stateful lookup each iteration.
            total_len = len(base_seq) + step + 1
            # % 5000 keeps generated tokens within the seeded vocabulary range.
            new_token = (step + max_trie_depth + 1) % 5000
            # Slide the window: keep only the last D tokens as the match context.
            tail = (base_seq + [new_token])[-max_trie_depth:]
            base_seq = tail
            _batch_get_with_state(corpus, req_id, tail, total_len)
        incremental_us = (time.perf_counter() - start) / num_steps * 1e6

        # --- Rebuild path: unique req_id each call forces fresh state ---
        # _batch_get (stateless) rebuilds the match from scratch every call,
        # which is the O(D^2) baseline being compared against.
        base_seq = list(range(1, max_trie_depth + 1))
        start = time.perf_counter()
        for step in range(num_steps):
            new_token = (step + max_trie_depth + 1) % 5000
            tail = (base_seq + [new_token])[-max_trie_depth:]
            base_seq = tail
            _batch_get(corpus, [tail])
        rebuild_us = (time.perf_counter() - start) / num_steps * 1e6

        print(
            f"\n  Incremental: {incremental_us:.1f} us/step"
            f"\n  Rebuild:     {rebuild_us:.1f} us/step"
            f"\n  Speedup:     {rebuild_us / incremental_us:.2f}x"
        )

        # The incremental path should be at least as fast; allow a small margin
        # for noise. With D=18 the theoretical speedup is ~18x (D^2 vs D).
        self.assertLess(
            incremental_us,
            rebuild_us * 1.1,
            f"Incremental ({incremental_us:.1f} us) should not be slower than "
            f"rebuild ({rebuild_us:.1f} us)",
        )


# Allow running this test/benchmark module directly from the command line.
if __name__ == "__main__":
    unittest.main(verbosity=3)
Loading