1 change: 1 addition & 0 deletions python/sglang/README.md
@@ -1,5 +1,6 @@
# Code Structures

- `eval`: Llama 3 evaluation and Loogle evaluation.
- `lang`: The frontend language.
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
- `test`: The test utilities.
File renamed without changes.
157 changes: 157 additions & 0 deletions python/sglang/eval/loogle_eval.py
@@ -0,0 +1,157 @@
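"""Evaluate a served model on the LooGLE long-dependency QA benchmark.

Each example of the ``longdep_qa`` test split is sent to an OpenAI-compatible
endpoint, the raw responses are cached on disk, and the answers are scored
against the references with BERTScore (average F1).
"""
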
import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List

import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm


def get_client(api_url: str) -> openai.AsyncOpenAI:
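    # The AsyncOpenAI client requires an API key; a dummy value suffices for
    # local servers that skip authentication.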
if os.getenv("OPENAI_API_KEY") is None:
os.environ["OPENAI_API_KEY"] = "EMPTY"
return openai.AsyncOpenAI(base_url=api_url)


def get_dataset():
return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")


async def fetch_response(
client: openai.AsyncOpenAI,
context: str,
question: str,
semaphore: asyncio.Semaphore,
index: int,
model: str,
output_dir: Path,
):
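    """Ask the model one question over its long context and cache the response.

    Responses (and request errors) are pickled per example so that interrupted
    runs can resume without re-querying the server.
    """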
output_file = output_dir / f"response_{index}.pkl"
if output_file.exists():
return

prompt = (
"Please answer the question based on the long texts below.\n"
f"{context}\n"
f"Question: {question}\n"
"Answer:"
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]

async with semaphore:
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=0.0,
max_tokens=512,
)
except openai.BadRequestError as e:
with open(output_file, "wb") as f:
pickle.dump({"error": str(e)}, f)
return

with open(output_file, "wb") as f:
pickle.dump(response, f)


async def benchmark(args):
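    """Send every dataset example to the server, bounded by a concurrency semaphore."""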
dataset = get_dataset()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

client = get_client(args.api_url)
semaphore = asyncio.Semaphore(args.max_concurrency)

tasks: List[asyncio.Task] = []
for idx, ex in enumerate(dataset):
tasks.append(
asyncio.create_task(
fetch_response(
client,
ex["context"],
ex["question"],
semaphore,
idx,
args.model,
output_dir,
)
)
)

    for coro in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await coro


def analyse(args):
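    """Score the cached responses against the reference answers with BERTScore F1."""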
dataset = get_dataset()
output_dir = Path(args.output_dir)

device = "cuda" if torch.cuda.is_available() else "cpu"
scorer = BERTScorer(lang="en", device=device)

hyps: List[str] = []
refs: List[str] = []
for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
pkl_file = output_dir / f"response_{idx}.pkl"
if not pkl_file.exists():
raise FileNotFoundError(pkl_file)

        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
        if isinstance(response, dict) and "error" in response:
            continue

hyps.append(response.choices[0].message.content.strip())
refs.append(ex["answer"])

if not hyps:
print("No valid responses to score!")
return

batch_size = 64
all_f1: List[float] = []
for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
h_batch = hyps[i : i + batch_size]
r_batch = refs[i : i + batch_size]
_, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
all_f1.extend([float(x) for x in f1_scores])

avg = sum(all_f1) / len(all_f1)
print(f"Average BERTScore (F1): {avg:.2%}")


if __name__ == "__main__":
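    # Example (assumes a local OpenAI-compatible server is already serving the model):
    #   python loogle_eval.py --api-url http://127.0.0.1:30000/v1 --model <served-model>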
parser = argparse.ArgumentParser(
description="Run benchmark and evaluation in one go."
)
parser.add_argument(
"--api-url",
default="http://127.0.0.1:30000/v1",
help="OpenAI‑compatible API base URL",
)
parser.add_argument(
"--model",
default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
help="Model name or ID",
)
parser.add_argument(
"--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
)
parser.add_argument(
"--output-dir", default="tmp-output-dir", help="Directory for cached responses"
)
args = parser.parse_args()

asyncio.run(benchmark(args))

analyse(args)