12 changes: 12 additions & 0 deletions .github/workflows/pr-test.yml
@@ -178,6 +178,18 @@ jobs:
          cd test/srt
          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8

      - name: Benchmark VLM offline throughput
        timeout-minutes: 10
        run: |
          cd test/srt
          python3 -m unittest test_bench_serving.TestBenchServing.test_vlm_offline_throughput

      - name: Benchmark VLM online latency
        timeout-minutes: 10
        run: |
          cd test/srt
          python3 -m unittest test_bench_serving.TestBenchServing.test_vlm_online_latency

  performance-test-2-gpu:
    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
      github.event.pull_request.draft == false
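The two new steps follow the same pattern as the existing FP8 benchmark step above; outside CI they reduce to plain unittest invocations. A minimal local-run sketch (assumes a checkout with the test dependencies and a GPU, as on the CI runner):

```python
# Sketch: what the two new workflow steps execute, runnable from the repo root.
# Environment assumptions (GPU, test dependencies) mirror the CI runner.
import subprocess

for test in (
    "test_bench_serving.TestBenchServing.test_vlm_offline_throughput",
    "test_bench_serving.TestBenchServing.test_vlm_online_latency",
):
    subprocess.run(["python3", "-m", "unittest", test], cwd="test/srt", check=True)
```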
150 changes: 149 additions & 1 deletion python/sglang/bench_serving.py
Expand Up @@ -58,6 +58,7 @@ class RequestFuncInput:
output_len: int
model: str
lora_name: str
image_data: str
extra_request_body: Dict[str, Any]


@@ -347,6 +348,11 @@ async def async_request_sglang_generate(
        "logprob_start_len": -1,
        **request_func_input.extra_request_body,
    }

    # Add image data if available
    if request_func_input.image_data:
        payload["image_data"] = request_func_input.image_data

    headers = get_auth_headers()

    output = RequestFuncOutput()
@@ -510,6 +516,13 @@ def get_dataset(args, tokenizer):
            tokenizer=tokenizer,
            args=args,
        )
    elif args.dataset_name == "mmmu":
        input_requests = sample_mmmu_requests(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.random_output_len,
            random_sample=True,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    return input_requests
@@ -597,6 +610,121 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
    return filename


def sample_mmmu_requests(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    random_sample: bool = True,
) -> List[Tuple[str, int, int]]:
    """
    Sample requests from the MMMU dataset using HuggingFace datasets.

    Args:
        num_requests: Number of requests to sample.
        tokenizer: Tokenizer to use for token counting.
        fixed_output_len: If provided, use this fixed output length for all requests.
        random_sample: Whether to randomly sample or take the first N.

    Returns:
        List of tuples (prompt, prompt_token_len, output_token_len).
    """
    try:
        import base64
        import io

        from datasets import load_dataset
    except ImportError:
        raise ImportError("Please install datasets: pip install datasets")

    print("Loading MMMU dataset from HuggingFace...")

    try:
        print("Attempting to load MMMU Math dataset...")
        mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
        print(
            f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
        )
    except Exception as e:
        print(f"Failed to load MMMU Math dataset: {e}")
        raise ValueError(f"Failed to load MMMU dataset: {e}")

    # Sample from the dataset
    if len(mmmu_dataset) > num_requests:
        if random_sample:
            # Random sample
            indices = random.sample(range(len(mmmu_dataset)), num_requests)
            sample_dataset = mmmu_dataset.select(indices)
        else:
            # Take first N
            sample_dataset = mmmu_dataset.select(
                range(min(num_requests, len(mmmu_dataset)))
            )
    else:
        print(f"Dataset has fewer than {num_requests} examples, using all examples")
        sample_dataset = mmmu_dataset

    print(f"Selected {len(sample_dataset)} examples for benchmarking")

    # Create prompts
    filtered_dataset = []

    for i, example in enumerate(sample_dataset):
        try:
            # Extract image_1; skip examples without a usable PIL image
            image = example.get("image_1")
            if image is None or not hasattr(image, "save"):
                continue

            # Convert RGBA images to RGB before encoding
            if image.mode == "RGBA":
                image = image.convert("RGB")

            # Encode image to base64
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            image_path = f"data:image/jpeg;base64,{img_str}"

            # Extract the question
            question = example.get("question")

            # Create the prompt with image, question
            prompt = f"Question: {question}\n\nAnswer: "
            prompt = tokenizer.apply_chat_template(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image_url", "image_url": {"url": image_path}},
                            {"type": "text", "text": prompt},
                        ],
                    }
                ],
                add_generation_prompt=True,
                tokenize=False,
            )
            prompt = f"<image>{image_path}</image>{prompt}"

            # Calculate token lengths
            # Note: This is approximate since we're not rendering the actual image tokens
            prompt_token_ids = tokenizer.encode(prompt)
            prompt_len = (
                len(prompt_token_ids) + 512
            )  # Add estimate for image tokens

            output_len = fixed_output_len if fixed_output_len is not None else 256

            filtered_dataset.append((prompt, prompt_len, output_len))

        except Exception as e:
            print(f"Error processing example {i}: {e}")

    print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
    return filtered_dataset


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
@@ -1004,6 +1132,15 @@ async def limited_request_func(request_func_input, pbar):
    else:
        lora_name = None

    if "<image>" in test_prompt:
        import re

        # re.DOTALL so multi-line chat-template prompts survive the split
        image_match = re.search(r"<image>(.*?)</image>(.*)", test_prompt, re.DOTALL)
        image_data = image_match.group(1) if image_match else None
        test_prompt = image_match.group(2) if image_match else test_prompt
    else:
        image_data = None

    # Create the test input once
    test_input = RequestFuncInput(
        model=model_id,
@@ -1012,6 +1149,7 @@
        prompt_len=test_prompt_len,
        output_len=min(test_output_len, 32),
        lora_name=lora_name,
        image_data=image_data,
        extra_request_body=extra_request_body,
    )

@@ -1063,13 +1201,23 @@ async def limited_request_func(request_func_input, pbar):
        else:
            lora_name = None

        if "<image>" in prompt:
            import re

            # re.DOTALL so multi-line chat-template prompts survive the split
            image_match = re.search(r"<image>(.*?)</image>(.*)", prompt, re.DOTALL)
            image_data = image_match.group(1) if image_match else None
            prompt = image_match.group(2) if image_match else prompt
        else:
            image_data = None

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            lora_name=lora_name,
            image_data=image_data,
            extra_request_body=extra_request_body,
        )
        tasks.append(
@@ -1444,7 +1592,7 @@ def __call__(self, parser, namespace, values, option_string=None):
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix", "mmmu"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
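Taken together, the bench_serving.py changes move image data through three stages: sample_mmmu_requests() hides a base64 data URL inside an `<image>...</image>` wrapper on the prompt, the benchmark loop splits the wrapper back out, and async_request_sglang_generate() attaches the extracted data to the request payload. A condensed sketch of that round trip (the prompt literal is an illustrative stand-in, not real MMMU data):

```python
# Condensed sketch of the image_data round trip added in this diff.
import re

# 1) sample_mmmu_requests() emits prompts of the form "<image>{data_url}</image>{chat_text}".
prompt = "<image>data:image/jpeg;base64,/9j/EXAMPLE</image><|im_start|>user\nQuestion: ...\n"

# 2) The benchmark loop splits the wrapper apart; re.DOTALL keeps multi-line chat text intact.
m = re.search(r"<image>(.*?)</image>(.*)", prompt, re.DOTALL)
image_data = m.group(1) if m else None
text = m.group(2) if m else prompt

# 3) async_request_sglang_generate() forwards it alongside the usual sampling fields.
payload = {"text": text, "image_data": image_data}
```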
3 changes: 2 additions & 1 deletion python/sglang/test/test_utils.py
@@ -79,7 +79,8 @@
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"
DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-VL-3B-Instruct"
DEFAULT_VLM_CHAT_TEMPLATE_FOR_TEST = "qwen2-vl"

DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
54 changes: 54 additions & 0 deletions test/srt/test_bench_serving.py
@@ -7,6 +7,8 @@
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST_FP8,
    DEFAULT_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
    DEFAULT_VLM_CHAT_TEMPLATE_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    run_bench_serving,
@@ -148,6 +150,58 @@ def test_online_latency_default(self):
            self.assertLess(res["median_ttft_ms"], 86)
            self.assertLess(res["median_itl_ms"], 10)

    def test_vlm_offline_throughput(self):
        res = run_bench_serving(
            model=DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
            num_prompts=200,
            request_rate=float("inf"),
            other_server_args=[
                "--chat-template",
                DEFAULT_VLM_CHAT_TEMPLATE_FOR_TEST,
                "--mem-fraction-static",
                "0.7",
            ],
            dataset_name="mmmu",
        )

        if is_in_ci():
            write_github_step_summary(
                f"### test_vlm_offline_throughput\n"
                f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
            )
            if os.getenv("SGLANG_AMD_CI") == "1":
                self.assertGreater(res["output_throughput"], 2000)
                # TODO: not set yet, need AMD machine
            else:
                self.assertGreater(res["output_throughput"], 2500)

    def test_vlm_online_latency(self):
        res = run_bench_serving(
            model=DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
            num_prompts=50,
            request_rate=1,
            other_server_args=[
                "--chat-template",
                DEFAULT_VLM_CHAT_TEMPLATE_FOR_TEST,
                "--mem-fraction-static",
                "0.7",
            ],
            dataset_name="mmmu",
        )

        if is_in_ci():
            write_github_step_summary(
                f"### test_vlm_online_latency\n"
                f'median_e2e_latency_ms: {res["median_e2e_latency_ms"]:.2f} ms\n'
            )
            self.assertLess(res["median_e2e_latency_ms"], 16000)
            if os.getenv("SGLANG_AMD_CI") == "1":
                self.assertLess(res["median_ttft_ms"], 150)
                # TODO: not set yet, need AMD machine
            else:
                self.assertLess(res["median_ttft_ms"], 90)
            self.assertLess(res["median_itl_ms"], 8)

    def test_online_latency_eagle(self):
        res = run_bench_serving(
            model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
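The benchmark these tests wrap can also be driven directly against a running server. A hedged command-line sketch (the prompt count and output length are illustrative; the flags are the ones visible in this diff, and an sglang server is assumed to be up at the default address):

```python
# Sketch: invoking bench_serving directly with the new "mmmu" dataset choice.
import subprocess

subprocess.run(
    [
        "python3", "-m", "sglang.bench_serving",
        "--backend", "sglang",
        "--dataset-name", "mmmu",      # new choice added by this diff
        "--num-prompts", "100",
        "--random-output-len", "256",  # reused as the fixed MMMU output length
    ],
    check=True,
)
```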
4 changes: 2 additions & 2 deletions test/srt/test_skip_tokenizer_init.py
@@ -16,7 +16,7 @@
from sglang.test.test_utils import (
    DEFAULT_IMAGE_URL,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_VLM_MODEL_NAME,
    DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
@@ -195,7 +195,7 @@ def setUpClass(cls):
        cls.image_url = DEFAULT_IMAGE_URL
        response = requests.get(cls.image_url)
        cls.image = Image.open(BytesIO(response.content))
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model, use_fast=False)
        cls.processor = AutoProcessor.from_pretrained(cls.model, trust_remote_code=True)
        cls.base_url = DEFAULT_URL_FOR_TEST