Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 1 addition & 265 deletions test/registered/spec/eagle/test_eagle_infer_a.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,18 @@
import os
import random
import unittest

import requests
import torch

import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.test_utils import (
DEFAULT_DRAFT_MODEL_EAGLE,
DEFAULT_DRAFT_MODEL_EAGLE3,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TARGET_MODEL_EAGLE,
DEFAULT_TARGET_MODEL_EAGLE3,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)

register_cuda_ci(est_time=561, suite="stage-b-test-1-gpu-large")

torch_dtype = torch.float16
prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2
register_cuda_ci(est_time=250, suite="stage-b-test-1-gpu-large")


class TestEAGLEEngine(CustomTestCase):
Expand Down Expand Up @@ -204,255 +190,5 @@ class TestEAGLE3Engine(TestEAGLEEngine):
}


class TestEAGLERadixCache(CustomTestCase):
    """Correctness tests for EAGLE3 speculative decoding with the radix cache.

    Several engine configurations are exercised (chunked prefill, various
    page sizes, restricted/disabled CUDA graphs) and, for each one, the
    speculative acceptance length and batch generation are sanity-checked.
    """

    # Baseline engine configuration; each sub-test starts from this dict and
    # overrides a few keys.
    BASE_CONFIG = {
        "model_path": DEFAULT_TARGET_MODEL_EAGLE3,
        "speculative_draft_model_path": DEFAULT_DRAFT_MODEL_EAGLE3,
        "speculative_algorithm": "EAGLE3",
        "speculative_num_steps": 2,
        "speculative_eagle_topk": 2,
        "speculative_num_draft_tokens": 5,
        "mem_fraction_static": 0.7,
        "dtype": "float16",
        "trust_remote_code": True,
        "attention_backend": "fa3",
        "skip_server_warmup": True,
        "cuda_graph_max_bs": 5,
    }

    def test_correctness(self):
        """Launch one engine per config variant and run both checks on it."""
        # The long chat-template prompt used below may exceed the model's
        # configured context length, so allow overriding it for this test.
        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
        try:
            configs = [
                # Basic config
                self.BASE_CONFIG,
                # Chunked prefill & Page Size > 1
                {**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4},
                {**self.BASE_CONFIG, "page_size": 4},
                # Large page size tend to expose IMA bugs.
                {**self.BASE_CONFIG, "page_size": 256},
                {**self.BASE_CONFIG, "cuda_graph_bs": [5], "page_size": 4},
                # Disable CUDA Graph
                {
                    **self.BASE_CONFIG,
                    "disable_cuda_graph": True,
                    "page_size": 4,
                },
            ]

            for i, config in enumerate(configs):
                with self.subTest(i=i):
                    print(f"{config=}")
                    engine = sgl.Engine(
                        **config, log_level="info", decode_log_interval=10
                    )
                    try:
                        self._test_acc_length(engine)
                        self._test_batch_generation(engine)
                    finally:
                        # Always release GPU memory before the next config.
                        engine.shutdown()
                    print("=" * 100)
        finally:
            # BUGFIX: restore process-global state even when an engine fails
            # to launch or a check raises, so later tests are unaffected.
            # Previously an early exception leaked this env var.
            del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"]

    def _test_acc_length(self, engine):
        """Check the speculative acceptance length on a long generation.

        A warmup request is issued first so the measured request runs with
        graphs captured and the cache populated.
        """
        warmup_prompt = [
            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
        ]
        sampling_params = {"temperature": 0, "max_new_tokens": 512}
        output = engine.generate(warmup_prompt, sampling_params)
        test_prompt = [
            "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        ]
        output = engine.generate(test_prompt, sampling_params)
        output = output[0]

        # Acceptance length = completion tokens per verification step;
        # 1.0 means speculation effectively accepted nothing.
        if "spec_verify_ct" in output["meta_info"]:
            acc_length = (
                output["meta_info"]["completion_tokens"]
                / output["meta_info"]["spec_verify_ct"]
            )
        else:
            acc_length = 1.0

        speed = (
            output["meta_info"]["completion_tokens"]
            / output["meta_info"]["e2e_latency"]
        )
        print(f"{acc_length=:.4f}, {speed=}")

        self.assertGreater(acc_length, 2.5)

    def _test_batch_generation(self, engine):
        """Generate a small batch and check the server-reported average
        speculative acceptance length."""
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 50}

        outputs = engine.generate(prompts, params)
        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)

        print(f"{engine.get_server_info()=}")

        avg_spec_accept_length = engine.get_server_info()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 2.0)


@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtend(CustomTestCase):
    """Draft-extend acceptance-length test against a live server launched
    with a minimal EAGLE configuration (1 step, topk 1, 2 draft tokens)."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_DRAFT_MODEL_EAGLE,
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
            "--max-running-requests",
            4,
            "--attention-backend",
            "fa3",
        ]
        cls.process = popen_launch_server(
            DEFAULT_TARGET_MODEL_EAGLE,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
        cls.accept_len_threshold = 1.50

    @classmethod
    def tearDownClass(cls):
        # Kill the server and all of its children.
        kill_process_tree(cls.process.pid)

    def test_one_batch_accept_length(self):
        """Generate one batch and require a minimum acceptance length."""
        flush_resp = requests.get(self.base_url + "/flush_cache")
        self.assertEqual(flush_resp.status_code, 200)

        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        payload = {
            "text": prompts,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 512,
            },
        }
        response = requests.post(self.base_url + "/generate", json=payload)
        self.assertEqual(response.status_code, 200)

        for output in response.json()[: len(prompts)]:
            meta = output["meta_info"]
            # 1.0 (no speculation credit) when the counter is absent.
            if "spec_verify_ct" in meta:
                acc_length = meta["completion_tokens"] / meta["spec_verify_ct"]
            else:
                acc_length = 1.0

            print(f"{acc_length=}")
            self.assertGreater(acc_length, self.accept_len_threshold)


class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
    """Same draft-extend check as the base class, but the server is
    launched with the flashinfer attention backend."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        server_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_DRAFT_MODEL_EAGLE,
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
            "--max-running-requests",
            4,
            "--attention-backend",
            "flashinfer",
        ]
        cls.process = popen_launch_server(
            DEFAULT_TARGET_MODEL_EAGLE,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
        )
        cls.accept_len_threshold = 1.50


@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
    """Same draft-extend check as the base class, but the server is
    launched with the triton attention backend."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Assemble the launch flags pairwise for readability.
        other_args = []
        other_args += ["--speculative-algorithm", "EAGLE"]
        other_args += ["--speculative-draft-model-path", DEFAULT_DRAFT_MODEL_EAGLE]
        other_args += ["--speculative-num-steps", 1]
        other_args += ["--speculative-eagle-topk", 1]
        other_args += ["--speculative-num-draft-tokens", 2]
        other_args += ["--max-running-requests", 4]
        other_args += ["--attention-backend", "triton"]
        cls.process = popen_launch_server(
            DEFAULT_TARGET_MODEL_EAGLE,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )
        cls.accept_len_threshold = 1.50


@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
    """Draft-extend check against an MLA target model with the flashinfer
    backend; no separate draft model path is passed, and a higher
    acceptance-length threshold is enforced."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        mla_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
            "--max-running-requests",
            4,
            "--attention-backend",
            "flashinfer",
        ]
        cls.process = popen_launch_server(
            DEFAULT_MODEL_NAME_FOR_TEST_MLA,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=mla_args,
        )
        cls.accept_len_threshold = 1.85


if __name__ == "__main__":
    # Allow running this test module directly with `python <file>`.
    unittest.main()
78 changes: 41 additions & 37 deletions test/registered/spec/eagle/test_eagle_infer_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
from sglang.test.server_fixtures.eagle_fixture import EagleServerBase
from sglang.test.test_utils import DEFAULT_TARGET_MODEL_EAGLE, run_logprob_check

register_cuda_ci(est_time=1100, suite="stage-b-test-1-gpu-large")
register_cuda_ci(est_time=600, suite="stage-b-test-1-gpu-large")


class TestEAGLEServerBasic(EagleServerBase):
"""Core tests that run on every server config variant."""

extra_args = ["--chunked-prefill-size", 128, "--max-running-requests", 8]

# FIXME(lsyin): move the test methods to kits
Expand All @@ -42,27 +44,6 @@ def test_request_abort(self):
for p in threads:
p.join()

def test_radix_attention(self):
run_radix_attention_test(self.base_url)
self.assertIsNone(self.process.poll())

def test_max_token_one(self):
requests.get(self.base_url + "/flush_cache")

args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=1,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)

# Just run and check it does not hang
metrics = run_gsm8k_eval(args)
self.assertGreater(metrics["output_throughput"], 50)

def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")

Expand Down Expand Up @@ -96,6 +77,44 @@ def test_gsm8k(self):
# Wait a little bit so that the memory check happens.
time.sleep(4)


class TestEAGLEServerAdditional(TestEAGLEServerBasic):
spec_topk = 5
spec_steps = 8
spec_tokens = 64
extra_args = [
"--max-running-requests",
8,
"--cuda-graph-max-bs",
5,
"--attention-backend",
"fa3",
"--page-size",
256,
"--dtype",
"float16",
]

    def test_radix_attention(self):
        """Run the radix-attention integration check against the live server,
        then verify the server process is still alive (poll() is None while
        the process is running)."""
        run_radix_attention_test(self.base_url)
        self.assertIsNone(self.process.poll())

def test_max_token_one(self):
requests.get(self.base_url + "/flush_cache")

args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=1,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)

metrics = run_gsm8k_eval(args)
self.assertGreater(metrics["output_throughput"], 50)

def test_logprob_start_len(self):
logprob_start_len = 4
new_tokens = 4
Expand Down Expand Up @@ -337,21 +356,6 @@ class TestEAGLEServerPageSizeTopk(TestEAGLEServerBasic):
]


class TestEAGLEServerPageSizeTopkFA3(TestEAGLEServerBasic):
    """Variant of the basic EAGLE server tests using the FA3 attention
    backend with a large page size, float16, and a small CUDA-graph
    batch-size budget."""

    # default topk=8 and tokens=64
    spec_topk = 5
    spec_steps = 8
    spec_tokens = 64

    # Extra CLI flags appended to the server launch command.
    extra_args = [
        "--page-size=256",
        "--attention-backend=fa3",
        "--cuda-graph-max-bs=5",
        "--dtype=float16",
        "--max-running-requests=8",
    ]


class TestEAGLEAbortAll(AbortAllMixin, EagleServerBase):
abort_all_max_new_tokens = 4000
extra_args = ["--max-running-requests=8"]
Expand Down
Loading