Skip to content
Closed
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"test_block_int8.py",
"test_int8_kernel.py",
"test_reasoning_content.py",
"test_cache.py",
],
"nightly": [
"test_nightly_gsm8k_eval.py",
Expand Down
221 changes: 221 additions & 0 deletions test/srt/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
"""
python3 -m unittest test_cache
"""

import time
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


# No need to test ChunkCache, as it's tested in test_chunked_prefill.py
class TestDisableRadixCache(unittest.TestCase):
    """Server tests with the radix cache disabled (--disable-radix-cache).

    One server process is shared by all tests in the class; it is launched in
    setUpClass and killed in tearDownClass.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--disable-radix-cache",
                "--enable-cache-report",
                "--enable-metrics",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree, not just the parent.
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Accuracy must not regress with the radix cache disabled."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)

    def run_decode(self, max_new_tokens):
        """Send one greedy /generate request and return the parsed JSON body.

        Args:
            max_new_tokens: number of tokens to force-decode (ignore_eos=True).
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
        # Fail fast on an HTTP error instead of JSON-decoding an error page.
        response.raise_for_status()
        return response.json()

    def test_throughput(self):
        """Decode throughput (tokens/s) must stay above a minimal floor."""
        # Warmup so one-time startup costs are not included in the timing.
        self.run_decode(16)

        max_tokens = 256
        # perf_counter is monotonic and higher-resolution than time.time().
        tic = time.perf_counter()
        res = self.run_decode(max_tokens)
        toc = time.perf_counter()
        print(f"{res=}")
        throughput = max_tokens / (toc - tic)
        print(f"Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 80)


class TestRadixCache(unittest.TestCase):
    """Server tests with the default radix cache enabled.

    One server process is shared by all tests in the class; it is launched in
    setUpClass and killed in tearDownClass.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-cache-report",
                "--enable-metrics",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree, not just the parent.
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Accuracy must hold with the radix cache enabled."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)

    def run_decode(self, max_new_tokens):
        """Send one greedy /generate request and return the parsed JSON body.

        Args:
            max_new_tokens: number of tokens to force-decode (ignore_eos=True).
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
        # Fail fast on an HTTP error instead of JSON-decoding an error page.
        response.raise_for_status()
        return response.json()

    def test_evict(self, num_iters=2):
        """Repeat the MMLU run so the cache fills up and eviction is exercised."""
        print(f"Running test_mmlu {num_iters} times...")
        start_time = time.perf_counter()
        for i in range(num_iters):
            print(f"Running iteration {i+1}/{num_iters}")
            self.test_mmlu()
        end_time = time.perf_counter()
        print(f"Total time: {end_time - start_time:.2f} seconds")

    def test_throughput(self):
        """Decode throughput (tokens/s) must stay above a minimal floor."""
        # Warmup so one-time startup costs are not included in the timing.
        self.run_decode(16)

        max_tokens = 256
        # perf_counter is monotonic and higher-resolution than time.time().
        tic = time.perf_counter()
        res = self.run_decode(max_tokens)
        toc = time.perf_counter()
        print(f"{res=}")
        throughput = max_tokens / (toc - tic)
        print(f"Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 80)


class TestHiCache(unittest.TestCase):
    """Server tests with the hierarchical cache (--enable-hierarchical-cache).

    One server process is shared by all tests in the class; it is launched in
    setUpClass and killed in tearDownClass.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-hierarchical-cache",
                "--enable-cache-report",
                "--enable-metrics",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree, not just the parent.
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Accuracy must hold with the hierarchical cache enabled."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)

    def run_decode(self, max_new_tokens):
        """Send one greedy /generate request and return the parsed JSON body.

        Args:
            max_new_tokens: number of tokens to force-decode (ignore_eos=True).
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
        # Fail fast on an HTTP error instead of JSON-decoding an error page.
        response.raise_for_status()
        return response.json()

    def test_evict(self, num_iters=2):
        """Repeat the MMLU run so the cache fills up and eviction is exercised."""
        print(f"Running test_mmlu {num_iters} times...")
        start_time = time.perf_counter()
        for i in range(num_iters):
            print(f"Running iteration {i+1}/{num_iters}")
            self.test_mmlu()
        end_time = time.perf_counter()
        print(f"Total time: {end_time - start_time:.2f} seconds")

    def test_throughput(self):
        """Decode throughput (tokens/s) must stay above a minimal floor."""
        # Warmup so one-time startup costs are not included in the timing.
        self.run_decode(16)

        max_tokens = 256
        # perf_counter is monotonic and higher-resolution than time.time().
        tic = time.perf_counter()
        res = self.run_decode(max_tokens)
        toc = time.perf_counter()
        print(f"{res=}")
        throughput = max_tokens / (toc - tic)
        print(f"Decode Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 80)


# Allow running this file directly: python3 -m unittest test_cache
if __name__ == "__main__":
    unittest.main()