Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions .buildkite/test-nightly-diffusion.yml
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,23 @@ steps:
if: *nightly_or_pr_label
commands:
- export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results
- export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN
- export CACHE_DIT_VERSION=1.3.0
- pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
- buildkite-agent artifact upload "tests/dfx/perf/results/benchmark_results_*.json"
- buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log"
# HACK: run the artifact upload in the same command block as pytest,
# because `exit` aborts the entire commands list.
- |
set +e
pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
EXIT1=$$?
pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json
EXIT2=$$?
pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json
EXIT3=$$?
if [ $$EXIT1 -eq 0 ] || [ $$EXIT2 -eq 0 ] || [ $$EXIT3 -eq 0 ]; then
buildkite-agent artifact upload "tests/dfx/perf/results/diffusion_result_*.json"
buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log"
fi
exit $$((EXIT1 | EXIT2 | EXIT3))
agents:
queue: "mithril-h100-pool"
plugins:
Expand Down
28 changes: 23 additions & 5 deletions benchmarks/diffusion/diffusion_benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool
super().__init__(args, api_url, model)
self.num_prompts = args.num_prompts
self.enable_negative_prompt = enable_negative_prompt
self.num_input_images = max(1, args.num_input_images)
self.random_request_config = getattr(args, "random_request_config", None)
if self.random_request_config:
self.random_request_config = json.loads(self.random_request_config)
Expand All @@ -580,11 +581,7 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool

# Random image generate
if self.args.task in ["i2v", "ti2v", "ti2i", "i2i"]:
img = Image.new("RGB", (512, 512), (255, 255, 255))

image_path = os.path.join(tempfile.gettempdir(), "diffusion_benchmark_random_image.png")
self._random_image_path = [image_path]
img.save(image_path)
self._random_image_path = self._generate_random_image_paths()
else:
self._random_image_path = None

Expand Down Expand Up @@ -619,6 +616,18 @@ def __getitem__(self, idx: int) -> RequestFuncInput:
def get_requests(self) -> list[RequestFuncInput]:
return [self[i] for i in range(len(self))]

def _generate_random_image_paths(self) -> list[str]:
    """Create the synthetic input images and return their file paths.

    Writes ``self.num_input_images`` solid-white 512x512 RGB PNG files into
    the system temp directory, named
    ``diffusion_benchmark_random_image_<index>.png``.

    Returns:
        The paths of the written images, in index order.
    """
    tmp_dir = tempfile.gettempdir()
    paths: list[str] = []
    for idx in range(self.num_input_images):
        target = os.path.join(tmp_dir, f"diffusion_benchmark_random_image_{idx}.png")
        # Placeholder content only: a plain white canvas.
        Image.new("RGB", (512, 512), (255, 255, 255)).save(target)
        paths.append(target)
    return paths


def _compute_expected_latency_ms_from_base(req: RequestFuncInput, args, base_time_ms: float | None) -> float | None:
"""Compute expected execution time (ms) based on a base per-step-per-frame unit time.
Expand Down Expand Up @@ -1115,6 +1124,15 @@ async def limited_request_func(req, session, pbar):
'{"width":768,"height":768,"num_inference_steps":20,"weight":0.85}]'
),
)
parser.add_argument(
"--num-input-images",
type=int,
default=1,
help=(
"Number of synthetic input images to attach for image-conditioned tasks "
"(i2v, ti2v, ti2i, i2i) when using random dataset."
),
)

args = parser.parse_args()

Expand Down
170 changes: 152 additions & 18 deletions tests/dfx/perf/scripts/run_diffusion_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import time
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Any, cast

import psutil
import pytest

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
os.environ.setdefault("DIFFUSION_ATTENTION_BACKEND", "FLASH_ATTN")
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DIFFUSION_ATTENTION_BACKEND is being defaulted to FLASH_ATTN at import time. This can force the FlashAttention backend even in environments where flash-attn isn’t installed (the backend raises ImportError and suggests using TORCH_SDPA), causing local runs to fail unexpectedly. Prefer leaving this env var unset by default (let platform selection decide), or set it conditionally only when FlashAttention is available / in CI where it’s guaranteed.

Suggested change
os.environ.setdefault("DIFFUSION_ATTENTION_BACKEND", "FLASH_ATTN")

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional: the benchmark is meant to run with the FlashAttention backend.


# ---------------------------------------------------------------------------
# Paths
Expand All @@ -50,6 +51,7 @@
# Populated lazily after CONFIG_FILE_PATH is resolved.
_SESSION_TIMESTAMP = datetime.now().strftime("%Y%m%d-%H%M%S")
_RESULT_LOCK = threading.Lock()
_BRANCHPOINT_COMMIT_SHA: str | None = None


def _get_config_file_from_argv() -> str | None:
Expand Down Expand Up @@ -110,7 +112,7 @@ def load_configs(config_path: str) -> list[dict[str, Any]]:
BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)

_config_stem = Path(CONFIG_FILE_PATH).stem # e.g. "test_qwen_image_vllm_omni"
AGGREGATED_RESULT_FILE = BENCHMARK_RESULT_DIR / f"benchmark_results_{_config_stem}_{_SESSION_TIMESTAMP}.json"
AGGREGATED_RESULT_FILE = BENCHMARK_RESULT_DIR / f"diffusion_result_{_config_stem}_{_SESSION_TIMESTAMP}.json"


def _append_to_aggregated_file(record: dict[str, Any]) -> None:
Expand Down Expand Up @@ -232,13 +234,13 @@ class DiffusionServer:

def __init__(
self,
model: str,
serve_args: list[str],
server_cfg: dict[str, Any],
*,
port: int | None = None,
) -> None:
self.model = model
self.serve_args = serve_args
self.server_cfg: dict[str, Any] = server_cfg
self.model = server_cfg["model"]
self.serve_args = server_cfg["serve_args"]
self.host = "127.0.0.1"
self.port = port if port is not None else _get_open_port()
self.proc: subprocess.Popen | None = None
Expand Down Expand Up @@ -299,6 +301,95 @@ def _build_serve_args(serve_args_dict: dict[str, Any]) -> list[str]:
return args


def _get_branchpoint_commit_sha() -> str:
    """Return the branch-point commit SHA against main.

    Uses git command: ``git merge-base HEAD origin/main``, executed from the
    directory four levels above this file (taken to be the repository root).
    The outcome is cached in the module-level ``_BRANCHPOINT_COMMIT_SHA`` so
    git is invoked at most once per process.

    Returns:
        The first line of git's stdout, or an empty string when the SHA
        cannot be determined. A failure is cached as well, so it is reported
        only once.
    """
    global _BRANCHPOINT_COMMIT_SHA
    if _BRANCHPOINT_COMMIT_SHA is not None:
        return _BRANCHPOINT_COMMIT_SHA

    repo_root = Path(__file__).parent.parent.parent.parent
    sha = ""
    try:
        # stderr is captured separately: merging it into stdout would let a
        # git warning line be mistaken for the SHA.
        completed = subprocess.run(
            ["git", "merge-base", "HEAD", "origin/main"],
            cwd=str(repo_root),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # Best effort: this is a reporting field and must never fail the run.
        detail = (getattr(e, "stderr", None) or "").strip()
        suffix = f" ({detail})" if detail else ""
        print(f"Warning: failed to get branch-point commit SHA: {e}{suffix}")
    else:
        lines = completed.stdout.strip().splitlines()
        if lines:
            sha = lines[0]
        else:
            print("Warning: failed to get branch-point commit SHA: git produced no output")

    _BRANCHPOINT_COMMIT_SHA = sha
    return sha


def _to_resolution_string(params: dict[str, Any]) -> str:
    """Render the benchmark resolution as ``"<width>x<height>"``.

    A missing ``width`` or ``height`` key is replaced by the placeholder
    ``"unknown width"`` / ``"unknown height"``.
    """
    w = params.get("width", "unknown width")
    h = params.get("height", "unknown height")
    return "{}x{}".format(w, h)


def _to_parallelism_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
    """Summarise the parallelism-related serve args as ``key=value`` pairs.

    Only the ``vllm-omni`` framework is inspected. Pairs are emitted in the
    fixed order of the key list below (not the dict's order) and joined with
    commas.

    Returns:
        The comma-separated summary, or ``"none"`` when no listed key is
        present or the framework is not ``vllm-omni``.
    """
    if framework != "vllm-omni":
        return "none"

    parallel_keys = (
        "num-gpus",
        "usp",
        "ulysses-degree",
        "ring",
        "ring-degree",
        "cfg-parallel-size",
        "vae-patch-parallel-size",
        "vae-use-tiling",
        "tensor-parallel-size",
    )
    rendered = [f"{name}={serve_args_dict[name]}" for name in parallel_keys if name in serve_args_dict]
    # Every entry contains "=", so the joined string is empty only when
    # nothing matched.
    return ",".join(rendered) or "none"


def _to_cache_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
    """Report the configured cache backend for the results record.

    Returns:
        ``str(serve_args_dict["cache-backend"])`` when the framework is
        ``vllm-omni`` and that key is present; otherwise ``"disabled"``.
    """
    is_omni = framework == "vllm-omni"
    if is_omni and "cache-backend" in serve_args_dict:
        return str(serve_args_dict["cache-backend"])
    return "disabled"


def _to_offload_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
    """Describe which offload flags appear in the serve args.

    Only key presence is checked; the associated values are not inspected.

    Returns:
        ``"enabled(<flag>[;<flag>])"`` listing the present offload flags in a
        fixed order, or ``"disabled"`` when none is present or the framework
        is not ``vllm-omni``.
    """
    if framework != "vllm-omni":
        return "disabled"

    candidates = ("enable-cpu-offload", "enable-layerwise-offload")
    active = [flag for flag in candidates if flag in serve_args_dict]
    if not active:
        return "disabled"
    return "enabled(" + ";".join(active) + ")"


def _to_compile_value(framework: str, serve_args_dict: dict[str, Any]) -> str:
    """Report whether compilation is considered enabled for this server.

    For ``vllm-omni`` the presence of the ``enforce-eager`` key (regardless
    of its value) yields ``"disabled"``, and its absence yields
    ``"enabled"``. Any other framework is reported as ``"disabled"``.
    """
    if framework != "vllm-omni":
        return "disabled"
    return "disabled" if "enforce-eager" in serve_args_dict else "enabled"


def _to_quantization_value(framework: str, serve_args_dict: dict[str, Any]) -> str:
    """Report the quantization setting for the results record.

    Returns:
        ``str(serve_args_dict["quantization"])`` when the framework is
        ``vllm-omni`` and that value is truthy; otherwise ``"disabled"``.
    """
    if framework != "vllm-omni":
        return "disabled"

    setting = serve_args_dict.get("quantization")
    if not setting:
        return "disabled"
    return str(setting)


def _unique_server_params(configs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Return one server-config dict per unique test_name."""
seen: set[str] = set()
Expand All @@ -310,12 +401,14 @@ def _unique_server_params(configs: list[dict[str, Any]]) -> list[dict[str, Any]]
seen.add(test_name)
if cfg.get("server_type", "vllm-omni") != "vllm-omni":
raise ValueError(f"Unsupported server_type in config: {cfg.get('server_type')}")
serve_args_dict = cfg["server_params"].get("serve_args", {})
result.append(
{
"test_name": test_name,
"server_type": "vllm-omni",
"model": cfg["server_params"]["model"],
"serve_args": _build_serve_args(cfg["server_params"].get("serve_args", {})),
"serve_args_dict": serve_args_dict,
"serve_args": _build_serve_args(serve_args_dict),
"benchmark_backend": "vllm-omni",
"server_params": cfg["server_params"],
}
Expand All @@ -334,9 +427,7 @@ def _test_param_mapping(configs: list[dict[str, Any]]) -> dict[str, list[dict]]:

def _make_server(server_cfg: dict[str, Any]) -> DiffusionServer:
"""Factory: return a vLLM-Omni diffusion server instance for the config."""
model = server_cfg["model"]
serve_args = server_cfg["serve_args"]
return DiffusionServer(model=model, serve_args=serve_args)
return DiffusionServer(server_cfg=server_cfg)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -364,7 +455,6 @@ def diffusion_server(request):
print(f"\nStarting {server_type} server for test: {test_name}")
with _make_server(server_cfg) as server:
server.test_name = test_name
server.server_params = server_cfg["server_params"]
print(f"{server_type} server started successfully")
yield server
print(f"{server_type} server stopping…")
Expand Down Expand Up @@ -402,16 +492,18 @@ def run_benchmark(
params: dict[str, Any],
test_name: str,
backend: str = "vllm-omni",
server_params: dict[str, Any] | None = None,
server_cfg: dict[str, Any] | None = None,
source_file: str = "",
) -> dict[str, Any]:
"""Run diffusion_benchmark_serving.py as a subprocess and return parsed metrics.

The raw metrics are written to a temporary file by the subprocess. After
the run completes the metrics are merged with full metadata (test_name,
backend, benchmark_params, timestamp) and appended to the session-wide
aggregated JSON file (AGGREGATED_RESULT_FILE). The temporary file is
removed afterwards. Subprocess stdout/stderr are tee'd to a .log file
under BENCHMARK_RESULT_DIR/logs/; its path is stored in the record.
backend, benchmark_params, timestamp, flat reporting fields) and appended
to the session-wide aggregated JSON file (AGGREGATED_RESULT_FILE). The
temporary file is removed afterwards. Subprocess stdout/stderr are tee'd
to a .log file under BENCHMARK_RESULT_DIR/logs/; its path is stored in
the record.
"""
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

Expand Down Expand Up @@ -495,14 +587,55 @@ def run_benchmark(
finally:
tmp_result_file.unlink(missing_ok=True)

server_cfg = server_cfg or {}
serve_args_dict = server_cfg.get("serve_args_dict", {})
if not isinstance(serve_args_dict, dict):
serve_args_dict = {}

completed = metrics.get("completed_requests", metrics.get("completed", 0))
failed = metrics.get("failed_requests", metrics.get("failed", 0))

record: dict[str, Any] = {
"test_name": test_name,
"backend": backend,
"timestamp": timestamp,
"server_params": server_params,
"server_params": server_cfg.get("server_params"),
"benchmark_params": params,
"result": metrics,
"log_file": str(log_file),
"Model": model,
"Framework": backend,
"Hardware": "",
"Deployment": "",
"Task": params.get("task", "t2i"),
"Dataset": params.get("dataset", "random"),
"resolution": _to_resolution_string(params),
"Parallelism": _to_parallelism_string(backend, serve_args_dict),
"max_concurrency": params.get("max-concurrency", ""),
"Cache": _to_cache_string(backend, serve_args_dict),
"Quantization": _to_quantization_value(backend, serve_args_dict),
"offload": _to_offload_string(backend, serve_args_dict),
"compile": _to_compile_value(backend, serve_args_dict),
"Attn_backend": os.environ.get("DIFFUSION_ATTENTION_BACKEND", ""),
"num_inference_steps": params.get("num-inference-steps", ""),
"completed": completed,
"failed": failed,
"throughput_qps": metrics.get("throughput_qps"),
"latency_mean": metrics.get("latency_mean"),
"latency_median": metrics.get("latency_median"),
"latency_p99": metrics.get("latency_p99"),
"latency_p95": metrics.get("latency_p95"),
"latency_p50": metrics.get("latency_p50"),
"peak_memory_mb_max": metrics.get("peak_memory_mb_max"),
"peak_memory_mb_mean": metrics.get("peak_memory_mb_mean"),
"peak_memory_mb_median": metrics.get("peak_memory_mb_median"),
"stage_durations_mean": metrics.get("stage_durations_mean"),
"stage_durations_p50": metrics.get("stage_durations_p50"),
"stage_durations_p99": metrics.get("stage_durations_p99"),
"commit_sha": _get_branchpoint_commit_sha(),
"build_id": os.environ.get("BUILDKITE_BUILD_ID", ""),
"build_url": os.environ.get("BUILDKITE_BUILD_URL", ""),
"source_file": source_file,
}
_append_to_aggregated_file(record)
print(f"\n Result appended to: {AGGREGATED_RESULT_FILE}")
Expand Down Expand Up @@ -565,7 +698,8 @@ def test_diffusion_performance_benchmark(diffusion_server, benchmark_params):
params=params,
test_name=test_name,
backend=backend,
server_params=diffusion_server.server_params,
server_cfg=getattr(diffusion_server, "server_cfg", {}),
source_file=cast(str, CONFIG_FILE_PATH),
)

print(f"\n{'=' * 60}")
Expand Down
Loading
Loading