2 changes: 1 addition & 1 deletion python/sglang/README.md
@@ -1,5 +1,6 @@
# Code Structures

- `eval`: The evaluation utilities.
- `lang`: The frontend language.
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
- `test`: The test utilities.
@@ -11,6 +12,5 @@
- `check_env.py`: Check the environment variables and dependencies.
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
- `utils.py`: Common utilities.
- `version.py`: Version info.
File renamed without changes.
157 changes: 157 additions & 0 deletions python/sglang/eval/loogle_eval.py
@@ -0,0 +1,157 @@
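"""Evaluate a model served behind an OpenAI-compatible API on LooGLE.

Sends every example of the LooGLE ``longdep_qa`` test split to the server,
caches each response to disk as a pickle, then scores the answers against
the references with BERTScore (F1).
"""
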
import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List

import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm


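# The AsyncOpenAI client insists on an API key even for local servers that
# ignore authentication, so fall back to a dummy value.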
def get_client(api_url: str) -> openai.AsyncOpenAI:
if os.getenv("OPENAI_API_KEY") is None:
os.environ["OPENAI_API_KEY"] = "EMPTY"
return openai.AsyncOpenAI(base_url=api_url)


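# LooGLE's long-dependency QA split: questions that require reasoning over
# long input contexts.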
def get_dataset():
return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")


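# Query the model for a single example. Responses are pickled per index, so
# an interrupted run resumes without re-querying completed examples.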
async def fetch_response(
client: openai.AsyncOpenAI,
context: str,
question: str,
semaphore: asyncio.Semaphore,
index: int,
model: str,
output_dir: Path,
):
output_file = output_dir / f"response_{index}.pkl"
if output_file.exists():
return

prompt = (
"Please answer the question based on the long texts below.\n"
f"{context}\n"
f"Question: {question}\n"
"Answer:"
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]

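    # Bound the number of in-flight requests. Rejected requests (e.g. prompts
    # that exceed the context window) are cached as errors and skipped later.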
async with semaphore:
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0.0,
max_tokens=512,
)
except openai.BadRequestError as e:
with open(output_file, "wb") as f:
pickle.dump({"error": str(e)}, f)
return

with open(output_file, "wb") as f:
pickle.dump(response, f)


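# Fan out one request task per dataset example, bounded by the semaphore.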
async def benchmark(args):
dataset = get_dataset()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

client = get_client(args.api_url)
semaphore = asyncio.Semaphore(args.max_concurrency)

tasks: List[asyncio.Task] = []
for idx, ex in enumerate(dataset):
tasks.append(
asyncio.create_task(
fetch_response(
client,
ex["context"],
ex["question"],
semaphore,
idx,
args.model,
output_dir,
)
)
)

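    # Drive all tasks to completion, updating the progress bar as each finishes.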
    for coro in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await coro


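# Score the cached responses: compute BERTScore F1 between model answers and
# the reference answers, skipping examples whose requests failed.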
def analyse(args):
dataset = get_dataset()
output_dir = Path(args.output_dir)

device = "cuda" if torch.cuda.is_available() else "cpu"
scorer = BERTScorer(lang="en", device=device)

hyps: List[str] = []
refs: List[str] = []
for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
pkl_file = output_dir / f"response_{idx}.pkl"
if not pkl_file.exists():
raise FileNotFoundError(pkl_file)

        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
if isinstance(response, dict) and "error" in response:
continue

hyps.append(response.choices[0].message.content.strip())
refs.append(ex["answer"])

if not hyps:
print("No valid responses to score!")
return

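    # Score in fixed-size batches to keep BERTScore's memory use bounded.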
batch_size = 64
all_f1: List[float] = []
for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
h_batch = hyps[i : i + batch_size]
r_batch = refs[i : i + batch_size]
_, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
all_f1.extend([float(x) for x in f1_scores])

avg = sum(all_f1) / len(all_f1)
print(f"Average BERTScore (F1): {avg:.2%}")


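# Phase 1: query the server for all examples; phase 2: score the cached responses.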
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the LooGLE benchmark and score responses with BERTScore."
    )
parser.add_argument(
"--api-url",
default="http://127.0.0.1:30000/v1",
help="OpenAI‑compatible API base URL",
)
parser.add_argument(
"--model",
default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
help="Model name or ID",
)
parser.add_argument(
"--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
)
parser.add_argument(
"--output-dir", default="tmp-output-dir", help="Directory for cached responses"
)
args = parser.parse_args()

asyncio.run(benchmark(args))

analyse(args)
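
To run end to end, point `--api-url` at any OpenAI-compatible server, e.g. `python python/sglang/eval/loogle_eval.py --api-url http://127.0.0.1:30000/v1`. A minimal sketch of inspecting one cached response outside the script, assuming the default `tmp-output-dir` and that example 0 completed (the path and index are illustrative, not part of the PR):

```python
import pickle
from pathlib import Path

# Load one cached response written by fetch_response above.
pkl_file = Path("tmp-output-dir") / "response_0.pkl"
with open(pkl_file, "rb") as f:
    response = pickle.load(f)

if isinstance(response, dict) and "error" in response:
    print("Request failed:", response["error"])  # cached BadRequestError
else:
    print(response.choices[0].message.content.strip())
```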