Skip to content
181 changes: 156 additions & 25 deletions benchmarks/diffusion/quantization_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@
--height 1024 --width 1024 \
--num-inference-steps 50 --seed 42

Offline-quantized checkpoint (SVDQuant / MXFP8 offline / ModelOpt FP4 etc.):
Use `--baseline-model` for the BF16 reference and `--quantization auto`
so the quantized run honors the on-disk `transformer/config.json
["quantization_config"]` instead of overriding it.

python benchmarks/diffusion/quantization_quality.py \
--baseline-model Tongyi-MAI/Z-Image-Turbo \
--model ultranationalism/Z-Image-Turbo-SVDQuant-NVFP4 \
--task t2i \
--quantization auto \
--prompts "a cup of coffee on a wooden table, morning light" \
--height 1024 --width 1024 \
--num-inference-steps 20 --seed 42

Output directory structure (--output-dir, default: ./quant_bench_output):
quant_bench_output/
baseline/ # BF16 outputs
Expand Down Expand Up @@ -136,8 +150,55 @@ def compute_lpips_video(
return float(np.mean(scores))


def _build_omni_kwargs(args, quantization=None):
"""Build kwargs dict for Omni() constructor."""
def _gpu_memory_gib(device_index: int = 0) -> float:
"""Total GPU memory currently used on `device_index`, in GiB.

Uses `nvidia-smi` because vllm-omni spawns model workers in separate
processes; `torch.cuda.memory_allocated()` from the bench driver
process reports 0 since the workers hold the model in their own CUDA
contexts. `nvidia-smi`'s `memory.used` aggregates all processes on the
device, which on a single-GPU benchmark gives us the right number.

Returns 0.0 if nvidia-smi is unavailable or fails — callers should
treat 0.0 as "unknown" (the markdown summary guards against divide-
by-zero in the reduction column).
"""
import shutil
import subprocess

if not shutil.which("nvidia-smi"):
return 0.0
try:
out = subprocess.run(
[
"nvidia-smi",
f"--id={device_index}",
"--query-gpu=memory.used",
"--format=csv,noheader,nounits",
],
capture_output=True,
text=True,
timeout=5,
check=True,
)
return int(out.stdout.strip()) / 1024 # MiB -> GiB
except (subprocess.CalledProcessError, ValueError, subprocess.TimeoutExpired):
return 0.0


def _build_omni_kwargs(args, quantization=None, model_override=None):
"""Build kwargs dict for Omni() constructor.

`model_override` lets the baseline and quantized runs target different
on-disk paths (offline-quantized checkpoints like SVDQuant ship their
own pipeline tree separate from the BF16 reference model).

`quantization="auto"` is a sentinel meaning "do not pass
`quantization_config` to Omni; let the on-disk
`transformer/config.json["quantization_config"]` drive the choice".
Useful for offline-quantized checkpoints where the method + per-layer
skip list are baked into the config file.
"""
from vllm_omni.diffusion.data import DiffusionParallelConfig

parallel_config = DiffusionParallelConfig(
Expand All @@ -146,23 +207,35 @@ def _build_omni_kwargs(args, quantization=None):
tensor_parallel_size=args.tensor_parallel_size,
)
kwargs = {
"model": args.model,
"model": model_override if model_override is not None else args.model,
"parallel_config": parallel_config,
"enforce_eager": args.enforce_eager,
}
if quantization:
if quantization and quantization != "auto":
kwargs["quantization_config"] = quantization
return kwargs


def _generate_image(omni, args, prompt, seed):
"""Generate a single image and return (PIL.Image, time_seconds, memory_gib)."""
"""Generate a single image; returns (PIL.Image, elapsed_s, peak_gib, weights_gib, activations_gib).

`weights_gib` is GPU memory snapshotted right before `generate()`
(the engine holds model params + persistent buffers; few transient
activations should be live). `peak_gib` is sampled right after
`generate()` returns — close to but not necessarily exactly the
high-water mark, since the diffusion loop may have freed temporary
buffers before returning. `activations_gib` = peak - weights.

Uses `nvidia-smi` instead of `torch.cuda.max_memory_allocated()` so
we capture the worker process's memory (vllm-omni runs the model in
a separate spawn-mode worker). See `_gpu_memory_gib`.
"""
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform

generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed)
torch.accelerator.reset_peak_memory_stats()
weights_mem = _gpu_memory_gib()
start = time.perf_counter()
outputs = omni.generate(
{"prompt": prompt},
Expand All @@ -174,23 +247,27 @@ def _generate_image(omni, args, prompt, seed):
),
)
elapsed = time.perf_counter() - start
peak_mem = torch.accelerator.max_memory_allocated() / (1024**3)
peak_mem = _gpu_memory_gib()
activations_mem = max(peak_mem - weights_mem, 0.0)

req_out = OmniRequestOutput.unwrap_result(outputs)
if not req_out.images:
raise ValueError("Could not extract image output from result.")
img = req_out.images[0]
return img, elapsed, peak_mem
return img, elapsed, peak_mem, weights_mem, activations_mem


def _generate_video(omni, args, prompt, seed):
"""Generate a video and return (np.ndarray [F,H,W,C], time_seconds, memory_gib)."""
"""Generate a video; returns (np.ndarray [F,H,W,C], elapsed_s, peak_gib, weights_gib, activations_gib).

See `_generate_image` for what each memory number means.
"""
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform

generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed)
torch.accelerator.reset_peak_memory_stats()
weights_mem = _gpu_memory_gib()
start = time.perf_counter()
outputs = omni.generate(
{"prompt": prompt, "negative_prompt": ""},
Expand All @@ -204,7 +281,8 @@ def _generate_video(omni, args, prompt, seed):
),
)
elapsed = time.perf_counter() - start
peak_mem = torch.accelerator.max_memory_allocated() / (1024**3)
peak_mem = _gpu_memory_gib()
activations_mem = max(peak_mem - weights_mem, 0.0)

first = outputs[0]
if hasattr(first, "request_output") and isinstance(first.request_output, list):
Expand Down Expand Up @@ -241,7 +319,7 @@ def _generate_video(omni, args, prompt, seed):
if frames_array.ndim == 5:
frames_array = frames_array[0]

return frames_array, elapsed, peak_mem
return frames_array, elapsed, peak_mem, weights_mem, activations_mem


def _free_gpu_memory():
Expand Down Expand Up @@ -274,20 +352,28 @@ def run_benchmark(args):
print("\n" + "=" * 60)
print("Running BF16 baseline...")
print("=" * 60)
bl_kwargs = _build_omni_kwargs(args, quantization=None)
# `--baseline-model` lets offline-quantized methods (SVDQuant etc.)
# point the baseline at a separate BF16 pipeline tree, while `--model`
# points at the quantized checkpoint.
baseline_model = getattr(args, "baseline_model", None) or args.model
bl_kwargs = _build_omni_kwargs(args, quantization=None, model_override=baseline_model)
omni_bl = Omni(**bl_kwargs)

baseline_outputs = {} # prompt -> (output, time, mem)
baseline_outputs = {} # prompt -> (output, time, peak_mem, weights_mem, activations_mem)
for prompt in prompts:
print(f" Generating: {prompt[:60]}...")
print(f" Generating: {prompt[:60]}...", flush=True)
if is_video:
out, t, mem = _generate_video(omni_bl, args, prompt, seed)
out, t, mem, weights, acts = _generate_video(omni_bl, args, prompt, seed)
else:
out, t, mem = _generate_image(omni_bl, args, prompt, seed)
baseline_outputs[prompt] = (out, t, mem)
out, t, mem, weights, acts = _generate_image(omni_bl, args, prompt, seed)
baseline_outputs[prompt] = (out, t, mem, weights, acts)
print(f" -> {t:.2f}s weights={weights:.2f}GiB peak={mem:.2f}GiB", flush=True)

bl_avg_time = np.mean([v[1] for v in baseline_outputs.values()])
bl_mem = baseline_outputs[prompts[0]][2] # use first prompt's memory
# First prompt's memory snapshot is canonical (matches PR #1470 convention).
bl_mem = baseline_outputs[prompts[0]][2]
bl_weights = baseline_outputs[prompts[0]][3]
bl_acts = baseline_outputs[prompts[0]][4]
omni_bl.shutdown()
del omni_bl
_free_gpu_memory()
Expand Down Expand Up @@ -321,15 +407,18 @@ def run_benchmark(args):

qt_outputs = {}
for prompt in prompts:
print(f" Generating: {prompt[:60]}...")
print(f" Generating: {prompt[:60]}...", flush=True)
if is_video:
out, t, mem = _generate_video(omni_qt, args, prompt, seed)
out, t, mem, weights, acts = _generate_video(omni_qt, args, prompt, seed)
else:
out, t, mem = _generate_image(omni_qt, args, prompt, seed)
qt_outputs[prompt] = (out, t, mem)
out, t, mem, weights, acts = _generate_image(omni_qt, args, prompt, seed)
qt_outputs[prompt] = (out, t, mem, weights, acts)
print(f" -> {t:.2f}s weights={weights:.2f}GiB peak={mem:.2f}GiB", flush=True)

qt_avg_time = np.mean([v[1] for v in qt_outputs.values()])
qt_mem = qt_outputs[prompts[0]][2]
qt_weights = qt_outputs[prompts[0]][3]
qt_acts = qt_outputs[prompts[0]][4]
omni_qt.shutdown()
del omni_qt
_free_gpu_memory()
Expand Down Expand Up @@ -359,14 +448,16 @@ def run_benchmark(args):

mean_lpips = np.mean([p["lpips"] for p in per_prompt])
speedup = bl_avg_time / qt_avg_time if qt_avg_time > 0 else float("inf")
mem_reduction = (bl_mem - qt_mem) / bl_mem * 100
mem_reduction = (bl_mem - qt_mem) / bl_mem * 100 if bl_mem > 0 else 0.0

all_results.append(
{
"config": config_label,
"avg_time": qt_avg_time,
"speedup": speedup,
"memory_gib": qt_mem,
"weights_gib": qt_weights,
"activations_gib": qt_acts,
"mem_reduction_pct": mem_reduction,
"mean_lpips": mean_lpips,
"per_prompt": per_prompt,
Expand Down Expand Up @@ -404,6 +495,30 @@ def run_benchmark(args):
lines.append("> LPIPS < 0.01 = imperceptible, > 0.1 = clearly noticeable.")
lines.append("")

# Memory profiling table (mirrors PR #1470 layout)
tp = args.tensor_parallel_size
lines.append("### Memory Profiling")
lines.append("")
lines.append(
f"First-prompt snapshot at {args.height}x{args.width}, "
f"{args.num_inference_steps} steps. "
"Weights = `memory_allocated()` before `generate()`; "
"Peak = `max_memory_allocated()` during `generate()`; "
"Activations = Peak − Weights."
)
lines.append("")
lines.append("| Config | Weights | Activations | Peak | Total Reduction |")
lines.append("|--------|---------|-------------|------|-----------------|")
lines.append(f"| BF16, TP={tp} | {bl_weights:.2f} GiB | {bl_acts:.2f} GiB | {bl_mem:.2f} GiB | — |")
for r in all_results:
reduction_pct = (bl_mem - r["memory_gib"]) / bl_mem * 100 if bl_mem > 0 else 0.0
lines.append(
f"| {r['config']}, TP={tp} | {r['weights_gib']:.2f} GiB "
f"| {r['activations_gib']:.2f} GiB | {r['memory_gib']:.2f} GiB "
f"| **{reduction_pct:.0f}%** |"
)
lines.append("")

# Per-prompt table
if len(prompts) > 1:
lines.append("### Per-Prompt LPIPS")
Expand Down Expand Up @@ -442,6 +557,16 @@ def parse_args():
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--model", required=True, help="Model name or local path.")
parser.add_argument(
"--baseline-model",
default=None,
help=(
"Optional BF16 model path/name for the baseline run. Defaults to "
"`--model`. Use this when the quantized checkpoint is a separate "
"on-disk pipeline tree (e.g. SVDQuant, MXFP8 offline) rather than "
"an online-quantized variant of the same model."
),
)
parser.add_argument(
"--task",
default="t2i",
Expand All @@ -452,7 +577,13 @@ def parse_args():
"--quantization",
nargs="+",
required=True,
help="One or more quantization methods to benchmark (e.g. fp8 int8 bitsandbytes).",
help=(
"One or more quantization methods to benchmark (e.g. fp8 int8 "
"bitsandbytes). Use the sentinel `auto` for offline-quantized "
"checkpoints (SVDQuant etc.) where the method + per-layer skip "
'list are baked into `transformer/config.json["quantization_config"]` '
"— the bench will not override it and the on-disk config drives the run."
),
)
parser.add_argument(
"--prompts",
Expand Down
38 changes: 37 additions & 1 deletion examples/offline_inference/text_to_image/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,42 @@
from vllm_omni.platforms import current_omni_platform


def _resolve_quantization_label(cli_quantization: str | None, model: str) -> str:
"""Label that reflects what the engine will actually run, not just CLI args.

`--quantization` overrides whatever is on disk. Otherwise vllm-omni
auto-detects from `transformer/config.json["quantization_config"]`
(see `OmniDiffusionConfig._propagate_quantization_from_tf_config`).
Mirror that lookup so the printed banner doesn't say "None (BF16)"
for a checkpoint that's actually going to load quantized.
"""
if cli_quantization:
return cli_quantization
try:
from vllm.transformers_utils.config import get_hf_file_to_dict

cfg = get_hf_file_to_dict("transformer/config.json", model)
except Exception:
cfg = None
if isinstance(cfg, dict):
qc = cfg.get("quantization_config")
if isinstance(qc, dict):
method = qc.get("quant_method") or qc.get("method")
if method == "component" or (
method is None
and any(isinstance(v, dict) and (v.get("quant_method") or v.get("method")) for v in qc.values())
):
default = qc.get("default")
if isinstance(default, dict):
inner = default.get("quant_method") or default.get("method")
if inner:
return f"{inner} (per-component, from checkpoint)"
return "component (from checkpoint)"
if method:
return f"{method} (from checkpoint)"
return "None (BF16)"


def is_nextstep_model(model_name: str) -> bool:
"""Check if the model is a NextStep model by reading its config."""
from vllm.transformers_utils.config import get_hf_file_to_dict
Expand Down Expand Up @@ -412,7 +448,7 @@ def main():
print(f" Model: {args.model}")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Cache backend: {cache_backend if cache_backend else 'None (no acceleration)'}")
print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}")
print(f" Quantization: {_resolve_quantization_label(args.quantization, args.model)}")
if ignored_layers:
print(f" Ignored layers: {ignored_layers}")
print(
Expand Down
Loading
Loading