Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions benchmarks/diffusion/diffusion_benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import argparse
import ast
import asyncio
import base64
import glob
import json
import logging
Expand Down Expand Up @@ -818,6 +819,71 @@ def calculate_metrics(
return metrics


def _save_generated_outputs(
outputs: list[RequestFuncOutput],
requests_list: list[RequestFuncInput],
save_dir: str,
) -> None:
"""Decode and save base64 images/videos from successful responses."""
os.makedirs(save_dir, exist_ok=True)
saved = 0

for idx, (req, out) in enumerate(zip(requests_list, outputs)):
if not out.success or not out.response_body:
continue

media_urls: list[str] = []

# Chat-completions style: choices[*].message.content[*].image_url.url
choices = out.response_body.get("choices", [])
if isinstance(choices, list):
for choice in choices:
content = (choice or {}).get("message", {}).get("content")
if not isinstance(content, list):
continue
for item in content:
if not isinstance(item, dict) or item.get("type") != "image_url":
continue
url = (item.get("image_url") or {}).get("url", "")
if isinstance(url, str) and url.startswith("data:"):
media_urls.append(url)

# Images endpoint style: data[*].b64_json
data_items = out.response_body.get("data", [])
if isinstance(data_items, list):
for data_item in data_items:
if not isinstance(data_item, dict):
continue
b64_json = data_item.get("b64_json", "")
if isinstance(b64_json, str) and b64_json:
media_urls.append(f"data:image/png;base64,{b64_json}")

for img_idx, url in enumerate(media_urls):
if "," not in url:
continue

try:
header, b64_data = url.split(",", 1)
ext = "png"
if "image/jpeg" in header:
ext = "jpg"
elif "image/webp" in header:
ext = "webp"
elif "video/mp4" in header:
ext = "mp4"

img_bytes = base64.b64decode(b64_data)
fname = f"req_{idx:04d}_{img_idx}.{ext}"
fpath = os.path.join(save_dir, fname)
with open(fpath, "wb") as f:
f.write(img_bytes)
saved += 1
except Exception as e:
print(f"Warning: failed to save image for request {idx}: {e}")

print(f"Saved {saved} generated image(s) to {save_dir}")


def wait_for_service(base_url: str, timeout: int = 120) -> None:
print(f"Waiting for service at {base_url}...")
start_time = time.time()
Expand Down Expand Up @@ -995,6 +1061,9 @@ async def limited_request_func(req, session, pbar):

print("\n" + "=" * 60)

if args.save_dir:
_save_generated_outputs(outputs, requests_list, args.save_dir)

if args.output_file:
with open(args.output_file, "w") as f:
json.dump(metrics, f, indent=2)
Expand Down Expand Up @@ -1069,8 +1138,10 @@ async def limited_request_func(req, session, pbar):
parser.add_argument(
"--warmup-num-inference-steps",
type=int,
default=1,
help="num_inference_steps used for warmup requests.",
default=2,
help="num_inference_steps used for warmup requests. "
"Must be >= 2 to ensure at least one denoising step is executed "
"(some models, e.g. Bagel, run num_timesteps-1 denoising iterations).",
)
parser.add_argument("--width", type=int, default=None, help="Image/Video width.")
parser.add_argument("--height", type=int, default=None, help="Image/Video height.")
Expand Down Expand Up @@ -1104,6 +1175,13 @@ async def limited_request_func(req, session, pbar):
default=3.0,
help="SLO target multiplier: slo_ms = estimated_exec_time_ms * slo_scale (default: 3).",
)
parser.add_argument(
"--save-dir",
type=str,
default=None,
help="Directory to save generated images/outputs for visual inspection. "
"If not set, generated outputs are discarded after metric collection.",
)
parser.add_argument("--disable-tqdm", action="store_true", help="Disable progress bar.")
parser.add_argument(
"--enable-negative-prompt",
Expand Down
Loading
Loading