Skip to content
Open
89 changes: 87 additions & 2 deletions benchmarks/diffusion/diffusion_benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import argparse
import ast
import asyncio
import base64
import glob
import json
import logging
Expand Down Expand Up @@ -818,6 +819,73 @@ def calculate_metrics(
return metrics


# Substring of the data-URL header -> file extension, checked in order.
# Anything that matches none of these falls back to "png".
_MEDIA_EXT_BY_MIME = (
    ("image/jpeg", "jpg"),
    ("image/webp", "webp"),
    ("video/mp4", "mp4"),
)


def _extension_for_data_url_header(header: str) -> str:
    """Return the file extension for the header part of a data URL."""
    for mime, ext in _MEDIA_EXT_BY_MIME:
        if mime in header:
            return ext
    return "png"


def _extract_media_data_urls(response_body: dict) -> list[str]:
    """Collect a data URL for every media item found in a response body.

    Two response layouts are recognised:

    * chat-completions style: ``choices[*].message.content[*].image_url.url``
      (only entries whose ``type`` is ``"image_url"`` and whose URL starts
      with ``data:`` are kept);
    * images-endpoint style: ``data[*].b64_json`` (wrapped into a
      ``data:image/png;base64,...`` URL).

    Entries that do not have the expected shape (``None``, wrong type,
    missing keys) are skipped instead of raising, so one malformed response
    cannot abort saving of the remaining ones.
    """
    media_urls: list[str] = []

    # Chat-completions style: choices[*].message.content[*].image_url.url
    choices = response_body.get("choices", [])
    if isinstance(choices, list):
        for choice in choices:
            message = choice.get("message") if isinstance(choice, dict) else None
            content = message.get("content") if isinstance(message, dict) else None
            if not isinstance(content, list):
                continue
            for item in content:
                if not isinstance(item, dict) or item.get("type") != "image_url":
                    continue
                image_url = item.get("image_url")
                url = image_url.get("url") if isinstance(image_url, dict) else None
                if isinstance(url, str) and url.startswith("data:"):
                    media_urls.append(url)

    # Images endpoint style: data[*].b64_json
    data_items = response_body.get("data", [])
    if isinstance(data_items, list):
        for data_item in data_items:
            if not isinstance(data_item, dict):
                continue
            b64_json = data_item.get("b64_json", "")
            if isinstance(b64_json, str) and b64_json:
                media_urls.append(f"data:image/png;base64,{b64_json}")

    return media_urls


def _save_generated_outputs(
    outputs: list[RequestFuncOutput],
    requests_list: list[RequestFuncInput],
    save_dir: str,
) -> None:
    """Decode and save base64 images/videos from successful responses.

    Args:
        outputs: Per-request results; only entries with ``success`` set and a
            non-empty dict ``response_body`` are inspected.
        requests_list: Requests paired positionally with ``outputs``. It is
            used only for that pairing, so iteration stops at the shorter of
            the two lists.
        save_dir: Destination directory; created if it does not exist.

    Files are written as ``req_<request index>_<media index>.<ext>``, where
    the extension is derived from the MIME type in the data-URL header.
    Decoding and write errors are logged and counted, never raised.
    """
    os.makedirs(save_dir, exist_ok=True)
    saved = 0
    failed = 0

    for idx, (_req, out) in enumerate(zip(requests_list, outputs)):
        if not out.success:
            continue
        body = out.response_body
        # A body that is empty or not a JSON object carries no media to save.
        if not body or not isinstance(body, dict):
            continue

        for img_idx, url in enumerate(_extract_media_data_urls(body)):
            header, sep, b64_data = url.partition(",")
            if not sep:
                # No "<header>,<payload>" separator: not a usable data URL.
                continue

            try:
                img_bytes = base64.b64decode(b64_data)
                ext = _extension_for_data_url_header(header)
                fpath = os.path.join(save_dir, f"req_{idx:04d}_{img_idx}.{ext}")
                with open(fpath, "wb") as f:
                    f.write(img_bytes)
            except Exception as e:
                # Best effort: a single bad payload or write error must not
                # stop the remaining outputs from being saved.
                failed += 1
                logger.warning(f"Failed to save image for request {idx}: {e}", exc_info=True)
            else:
                saved += 1

    logger.info(f"Saved {saved} generated image(s) to {save_dir}. Failed to save {failed} image(s).")


def wait_for_service(base_url: str, timeout: int = 120) -> None:
print(f"Waiting for service at {base_url}...")
start_time = time.time()
Expand Down Expand Up @@ -995,6 +1063,9 @@ async def limited_request_func(req, session, pbar):

print("\n" + "=" * 60)

if args.save_dir:
_save_generated_outputs(outputs, requests_list, args.save_dir)

if args.output_file:
with open(args.output_file, "w") as f:
json.dump(metrics, f, indent=2)
Expand Down Expand Up @@ -1066,11 +1137,18 @@ async def limited_request_func(req, session, pbar):
default=1,
help="Number of warmup requests to run before measurement.",
)
# NOTE: The default is 2 rather than 1 because some models (e.g., Bagel) run
# `num_timesteps - 1` denoising iterations; a value of 1 would yield 0 steps,
# which causes errors.
# TODO: If the slightly longer warmup causes regressions for other diffusion
# pipelines, consider model-specific overrides instead of a global default.
parser.add_argument(
"--warmup-num-inference-steps",
type=int,
default=1,
help="num_inference_steps used for warmup requests.",
default=2,
help="Number of inference steps used for warmup requests. "
"Default is 2 to ensure at least one denoising step is executed.",
)
parser.add_argument("--width", type=int, default=None, help="Image/Video width.")
parser.add_argument("--height", type=int, default=None, help="Image/Video height.")
Expand Down Expand Up @@ -1104,6 +1182,13 @@ async def limited_request_func(req, session, pbar):
default=3.0,
help="SLO target multiplier: slo_ms = estimated_exec_time_ms * slo_scale (default: 3).",
)
parser.add_argument(
"--save-dir",
type=str,
default=None,
help="Directory to save generated images/outputs for visual inspection. "
"If not set, generated outputs are discarded after metric collection.",
)
parser.add_argument("--disable-tqdm", action="store_true", help="Disable progress bar.")
parser.add_argument(
"--enable-negative-prompt",
Expand Down
Loading
Loading