Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 72 additions & 80 deletions test/registered/spec/test_ngram_speculative_decoding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import json
import os
import tempfile
import unittest

import requests
Expand Down Expand Up @@ -35,21 +32,6 @@
0.8,
]

EXTERNAL_SAM_CORPUS_RECORDS = [
"The capital of France is Paris.",
"The answer to life, the universe, and everything is 42.",
]


def _safe_remove(path: str):
if os.path.exists(path):
os.remove(path)


def _safe_kill_process(process):
if process is not None and process.poll() is None:
kill_process_tree(process.pid)


class TestNgramSpeculativeDecodingBase(GSM8KMixin, CustomTestCase):
model = DEFAULT_TARGET_MODEL_NGRAM
Expand Down Expand Up @@ -91,7 +73,78 @@ def get_server_args(cls):
class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase):
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]
return DEFAULT_SERVER_ARGS + [
"--attention-backend",
"flashinfer",
"--speculative-ngram-external-sam-budget",
"8",
]

    def test_output_as_corpus_boosts_accept_length(self):
        """Baseline → HTTP add corpus → verify accept length boost.

        Three phases against the already-running server:
          1. Generate completions with no SAM corpus loaded and record the
             server-reported average speculative accept length.
          2. Flush the server cache, then POST the phase-1 outputs back as an
             external corpus via ``/add_external_corpus``.
          3. Re-run the same prompts and assert the accept length more than
             doubles, since the corpus now contains the greedy outputs.
        """
        prompts = [
            "The capital of France is",
            "In mathematics, the Pythagorean theorem states that",
            "The speed of light in a vacuum is approximately",
            "Water boils at a temperature of",
            "The largest planet in our solar system is",
        ]
        max_new_tokens = 128
        # Repeat each phase several times so the running average reported by
        # the server reflects the steady state, not the first request.
        num_rounds = 3

        def generate_batch():
            # Greedy decoding (temperature 0) so repeated rounds produce the
            # same text and the phase-2 corpus matches phase-3 generations.
            outputs = []
            for prompt in prompts:
                resp = requests.post(
                    self.base_url + "/generate",
                    json={
                        "text": prompt,
                        "sampling_params": {
                            "temperature": 0,
                            "max_new_tokens": max_new_tokens,
                        },
                    },
                    timeout=120,
                )
                self.assertEqual(resp.status_code, 200, resp.text)
                outputs.append(resp.json()["text"])
            return outputs

        def get_accept_length():
            # Average speculative accept length as tracked by the server.
            info = requests.get(self.base_url + "/get_server_info").json()
            return info["internal_states"][0]["avg_spec_accept_length"]

        # Phase 1: baseline — no SAM corpus loaded, only trie
        generated_outputs = []
        for _ in range(num_rounds):
            generated_outputs = generate_batch()
        baseline_accept_len = get_accept_length()
        print(f"\n Baseline accept length (no SAM): {baseline_accept_len:.2f}")

        # Flush cache so phase 2 starts clean
        requests.post(self.base_url + "/flush_cache", timeout=30)

        # Phase 2: add generated outputs as corpus via HTTP API
        resp = requests.post(
            self.base_url + "/add_external_corpus",
            json={"corpus_id": "bench", "documents": generated_outputs},
            timeout=120,
        )
        self.assertEqual(resp.status_code, 200, resp.text)
        self.assertTrue(resp.json()["success"], resp.json().get("message"))

        # Phase 3: regenerate the same prompts; drafts can now hit the corpus.
        for _ in range(num_rounds):
            generate_batch()
        sam_accept_len = get_accept_length()
        print(f" SAM accept length (output as corpus): {sam_accept_len:.2f}")
        print(f" Speedup: {sam_accept_len / baseline_accept_len:.2f}x")

        self.assertGreater(
            sam_accept_len,
            baseline_accept_len * 2.0,
            f"SAM accept length ({sam_accept_len:.2f}) should be at least 2x "
            f"baseline ({baseline_accept_len:.2f}) when corpus matches output",
        )


class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase):
Expand All @@ -106,66 +159,5 @@ def get_server_args(cls):
]


class TestNgramExternalSamSmoke(CustomTestCase):
    """Smoke test: the server boots and generates with an external SAM corpus.

    Writes a tiny JSONL corpus to a temp file, launches a server with
    ``--speculative-ngram-external-corpus-path`` pointing at it, and checks
    that a single ``/generate`` request completes with a non-empty
    completion — once per attention backend.
    """

    model = DEFAULT_TARGET_MODEL_NGRAM
    base_url = DEFAULT_URL_FOR_TEST
    # Each backend gets its own server launch inside the test loop below.
    attention_backends = ("triton", "flashinfer")

    def get_server_args(self, attention_backend):
        """Build the server CLI args for *attention_backend* plus SAM corpus flags."""
        return DEFAULT_SERVER_ARGS + [
            "--attention-backend",
            attention_backend,
            "--speculative-ngram-external-corpus-path",
            self.external_corpus_path,
            "--speculative-ngram-external-sam-budget",
            "4",
        ]

    @classmethod
    def setUpClass(cls):
        """Write the JSONL corpus file once for the whole class.

        NOTE(review): does not call ``super().setUpClass()`` — confirm
        CustomTestCase has no required class-level setup of its own.
        """
        # Presumably disables DeepGEMM JIT work to speed up server launch
        # in CI — confirm against the envs module.
        envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False)
        envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
        # One JSON-encoded string per line (JSONL format).
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".jsonl", prefix="ngram_external_sam_", delete=False
        ) as f:
            for record in EXTERNAL_SAM_CORPUS_RECORDS:
                f.write(json.dumps(record))
                f.write("\n")
        cls.external_corpus_path = f.name
        # delete=False above, so remove the temp file after all tests finish.
        cls.addClassCleanup(_safe_remove, cls.external_corpus_path)

    def _run_external_sam_smoke(self, attention_backend):
        """Launch a server for *attention_backend*, run one generate, tear down."""
        process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.get_server_args(attention_backend),
        )
        try:
            response = requests.post(
                self.base_url + "/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 8,
                    },
                },
                timeout=120,
            )
            self.assertEqual(response.status_code, 200, response.text)
            response_json = response.json()
            self.assertIn("text", response_json)
            self.assertIn("meta_info", response_json)
            # At least one completion token must have been produced.
            self.assertGreater(response_json["meta_info"]["completion_tokens"], 0)
        finally:
            # Always tear the server down, even when an assertion fails.
            _safe_kill_process(process)

    def test_generate_with_external_sam(self):
        """Run the smoke check once per configured attention backend."""
        for attention_backend in self.attention_backends:
            with self.subTest(attention_backend=attention_backend):
                self._run_external_sam_smoke(attention_backend)


if __name__ == "__main__":
    # Allow running this test module directly via the stdlib unittest runner.
    unittest.main()
Loading