diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 754fe9a79fa8..3ec9aa13a043 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -63,6 +63,7 @@
         "test_block_int8.py",
         "test_int8_kernel.py",
         "test_reasoning_content.py",
+        "test_cache.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
diff --git a/test/srt/test_cache.py b/test/srt/test_cache.py
new file mode 100644
index 000000000000..9bbef61abf35
--- /dev/null
+++ b/test/srt/test_cache.py
@@ -0,0 +1,257 @@
+"""
+Accuracy and decode-throughput tests for servers launched with different cache flags.
+
+Usage:
+python3 -m unittest test_cache
+"""
+
+import time
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+# Connect/read timeout, in seconds, for a single /generate request so that a
+# stalled server fails the test instead of blocking the suite indefinitely.
+GENERATE_REQUEST_TIMEOUT = 120
+
+
+# No need to test ChunkCache, as it's tested in test_chunked_prefill.py
+class TestDisableRadixCache(unittest.TestCase):
+    """Run the checks against a server launched with --disable-radix-cache."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--disable-radix-cache",
+                "--enable-cache-report",
+                "--enable-metrics",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        """Require an MMLU score of at least 0.65 on 64 examples."""
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+    def run_decode(self, max_new_tokens):
+        """POST a temperature-0 /generate request and return the parsed JSON body."""
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                    "ignore_eos": True,
+                },
+            },
+            timeout=GENERATE_REQUEST_TIMEOUT,
+        )
+        # A fast error response must not be mistaken for a fast decode.
+        response.raise_for_status()
+        return response.json()
+
+    def test_throughput(self):
+        """Require at least 80 tokens/s for one 256-token decode request."""
+        # Warmup
+        self.run_decode(16)
+
+        max_tokens = 256
+        # perf_counter is monotonic, so clock adjustments cannot skew the result.
+        tic = time.perf_counter()
+        res = self.run_decode(max_tokens)
+        tok = time.perf_counter()
+        print(f"{res=}")
+        throughput = max_tokens / (tok - tic)
+        print(f"Throughput: {throughput} tokens/s")
+        self.assertGreaterEqual(throughput, 80)
+
+
+class TestRadixCache(unittest.TestCase):
+    """Run the checks against a server launched without a cache-selection flag."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--enable-cache-report",
+                "--enable-metrics",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        """Require an MMLU score of at least 0.65 on 64 examples."""
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+    def run_decode(self, max_new_tokens):
+        """POST a temperature-0 /generate request and return the parsed JSON body."""
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                    "ignore_eos": True,
+                },
+            },
+            timeout=GENERATE_REQUEST_TIMEOUT,
+        )
+        # A fast error response must not be mistaken for a fast decode.
+        response.raise_for_status()
+        return response.json()
+
+    def test_evict(self, num_iters=2):
+        """Repeat the MMLU check num_iters times and print the total time."""
+        # Run a few times to ensure eviction happens
+        print(f"Running test_mmlu {num_iters} times...")
+        start_time = time.perf_counter()
+        for i in range(num_iters):
+            print(f"Running iteration {i+1}/{num_iters}")
+            self.test_mmlu()
+        end_time = time.perf_counter()
+        print(f"Total time: {end_time - start_time:.2f} seconds")
+
+    def test_throughput(self):
+        """Require at least 80 tokens/s for one 256-token decode request."""
+        # Warmup
+        self.run_decode(16)
+
+        max_tokens = 256
+        # perf_counter is monotonic, so clock adjustments cannot skew the result.
+        tic = time.perf_counter()
+        res = self.run_decode(max_tokens)
+        tok = time.perf_counter()
+        print(f"{res=}")
+        throughput = max_tokens / (tok - tic)
+        print(f"Throughput: {throughput} tokens/s")
+        self.assertGreaterEqual(throughput, 80)
+
+
+class TestHiCache(unittest.TestCase):
+    """Run the checks against a server with --enable-hierarchical-cache."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--enable-hierarchical-cache",
+                "--enable-cache-report",
+                "--enable-metrics",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        """Require an MMLU score of at least 0.65 on 64 examples."""
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+    def run_decode(self, max_new_tokens):
+        """POST a temperature-0 /generate request and return the parsed JSON body."""
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                    "ignore_eos": True,
+                },
+            },
+            timeout=GENERATE_REQUEST_TIMEOUT,
+        )
+        # A fast error response must not be mistaken for a fast decode.
+        response.raise_for_status()
+        return response.json()
+
+    def test_evict(self, num_iters=2):
+        """Repeat the MMLU check num_iters times and print the total time."""
+        print(f"Running test_mmlu {num_iters} times...")
+        start_time = time.perf_counter()
+        for i in range(num_iters):
+            print(f"Running iteration {i+1}/{num_iters}")
+            self.test_mmlu()
+        end_time = time.perf_counter()
+        print(f"Total time: {end_time - start_time:.2f} seconds")
+
+    def test_throughput(self):
+        """Require at least 80 tokens/s for one 256-token decode request."""
+        # Warmup
+        self.run_decode(16)
+
+        max_tokens = 256
+        # perf_counter is monotonic, so clock adjustments cannot skew the result.
+        tic = time.perf_counter()
+        res = self.run_decode(max_tokens)
+        tok = time.perf_counter()
+        print(f"{res=}")
+        throughput = max_tokens / (tok - tic)
+        print(f"Decode Throughput: {throughput} tokens/s")
+        self.assertGreaterEqual(throughput, 80)
+
+
+if __name__ == "__main__":
+    unittest.main()