diff --git a/test/registered/spec/test_ngram_speculative_decoding.py b/test/registered/spec/test_ngram_speculative_decoding.py index 690e6ec45f8f..f80b1e646dea 100644 --- a/test/registered/spec/test_ngram_speculative_decoding.py +++ b/test/registered/spec/test_ngram_speculative_decoding.py @@ -1,6 +1,3 @@ -import json -import os -import tempfile import unittest import requests @@ -35,21 +32,6 @@ 0.8, ] -EXTERNAL_SAM_CORPUS_RECORDS = [ - "The capital of France is Paris.", - "The answer to life, the universe, and everything is 42.", -] - - -def _safe_remove(path: str): - if os.path.exists(path): - os.remove(path) - - -def _safe_kill_process(process): - if process is not None and process.poll() is None: - kill_process_tree(process.pid) - class TestNgramSpeculativeDecodingBase(GSM8KMixin, CustomTestCase): model = DEFAULT_TARGET_MODEL_NGRAM @@ -91,7 +73,78 @@ def get_server_args(cls): class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase): @classmethod def get_server_args(cls): - return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"] + return DEFAULT_SERVER_ARGS + [ + "--attention-backend", + "flashinfer", + "--speculative-ngram-external-sam-budget", + "8", + ] + + def test_output_as_corpus_boosts_accept_length(self): + """Baseline → HTTP add corpus → verify accept length boost.""" + prompts = [ + "The capital of France is", + "In mathematics, the Pythagorean theorem states that", + "The speed of light in a vacuum is approximately", + "Water boils at a temperature of", + "The largest planet in our solar system is", + ] + max_new_tokens = 128 + num_rounds = 3 + + def generate_batch(): + outputs = [] + for prompt in prompts: + resp = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + }, + timeout=120, + ) + self.assertEqual(resp.status_code, 200, resp.text) + outputs.append(resp.json()["text"]) + return outputs + + def get_accept_length(): + info = requests.get(self.base_url + "/get_server_info").json() + return info["internal_states"][0]["avg_spec_accept_length"] + + # Phase 1: baseline — no SAM corpus loaded, only trie + generated_outputs = [] + for _ in range(num_rounds): + generated_outputs = generate_batch() + baseline_accept_len = get_accept_length() + print(f"\n Baseline accept length (no SAM): {baseline_accept_len:.2f}") + + # Flush cache so phase 2 starts clean + requests.post(self.base_url + "/flush_cache", timeout=30) + + # Phase 2: add generated outputs as corpus via HTTP API + resp = requests.post( + self.base_url + "/add_external_corpus", + json={"corpus_id": "bench", "documents": generated_outputs}, + timeout=120, + ) + self.assertEqual(resp.status_code, 200, resp.text) + self.assertTrue(resp.json()["success"], resp.json().get("message")) + + for _ in range(num_rounds): + generate_batch() + sam_accept_len = get_accept_length() + print(f" SAM accept length (output as corpus): {sam_accept_len:.2f}") + print(f" Speedup: {sam_accept_len / baseline_accept_len:.2f}x") + + self.assertGreater( + sam_accept_len, + baseline_accept_len * 2.0, + f"SAM accept length ({sam_accept_len:.2f}) should be at least 2x " + f"baseline ({baseline_accept_len:.2f}) when corpus matches output", + ) class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase): @@ -106,66 +159,5 @@ def get_server_args(cls): ] -class TestNgramExternalSamSmoke(CustomTestCase): - model = DEFAULT_TARGET_MODEL_NGRAM - base_url = DEFAULT_URL_FOR_TEST - attention_backends = ("triton", "flashinfer") - - def get_server_args(self, attention_backend): - return DEFAULT_SERVER_ARGS + [ - "--attention-backend", - attention_backend, - "--speculative-ngram-external-corpus-path", - self.external_corpus_path, - "--speculative-ngram-external-sam-budget", - "4", - ] - - @classmethod - def setUpClass(cls): - envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False) - envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".jsonl", prefix="ngram_external_sam_", delete=False - ) as f: - for record in EXTERNAL_SAM_CORPUS_RECORDS: - f.write(json.dumps(record)) - f.write("\n") - cls.external_corpus_path = f.name - cls.addClassCleanup(_safe_remove, cls.external_corpus_path) - - def _run_external_sam_smoke(self, attention_backend): - process = popen_launch_server( - self.model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=self.get_server_args(attention_backend), - ) - try: - response = requests.post( - self.base_url + "/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 8, - }, - }, - timeout=120, - ) - self.assertEqual(response.status_code, 200, response.text) - response_json = response.json() - self.assertIn("text", response_json) - self.assertIn("meta_info", response_json) - self.assertGreater(response_json["meta_info"]["completion_tokens"], 0) - finally: - _safe_kill_process(process) - - def test_generate_with_external_sam(self): - for attention_backend in self.attention_backends: - with self.subTest(attention_backend=attention_backend): - self._run_external_sam_smoke(attention_backend) - - if __name__ == "__main__": unittest.main()