diff --git a/docs/user_guide/diffusion/lora.md b/docs/user_guide/diffusion/lora.md index 256698752a..06845563ab 100644 --- a/docs/user_guide/diffusion/lora.md +++ b/docs/user_guide/diffusion/lora.md @@ -1,14 +1,19 @@ # LoRA (Low-Rank Adaptation) Guide -LoRA (Low-Rank Adaptation) enables fine-tuning diffusion models by adding trainable low-rank matrices to existing model weights. vLLM-Omni currently supports PEFT-style LoRA adapters, allowing you to customize model behavior without modifying the base model weights. +LoRA (Low-Rank Adaptation) enables fine-tuning diffusion models by adding trainable low-rank matrices to existing model weights. vLLM-Omni supports PEFT-style LoRA adapters, allowing you to customize model behavior without modifying the base model weights. ## Overview -LoRA adapters are lightweight, model-specific fine-tuning weights that can be dynamically loaded and applied to diffusion models. vLLM-Omni uses a unified LoRA handling mechanism similar to vLLM with LRU cache management. +vLLM-Omni exposes two complementary LoRA flows for diffusion models: + +1. **Init-time LoRA**: a single adapter is pre-loaded when `Omni` starts and is applied to every request. Lowest runtime overhead; best when all requests should share the same adapter. +2. **Per-request LoRA**: zero or more adapters are attached to each request via `sampling_params.lora_requests`. Supports switching adapters between requests and composing multiple adapters in a single forward pass (multi-LoRA). + +Adapters are managed by an LRU cache so repeated activations avoid redundant weight reloads. ## LoRA Adapter Format -LoRA adapters must be in **PEFT (Parameter-Efficient Fine-Tuning)** format. A typical LoRA adapter directory structure: +LoRA adapters must be in **PEFT (Parameter-Efficient Fine-Tuning)** format. 
A typical adapter directory: ``` lora_adapter/ @@ -16,45 +21,166 @@ lora_adapter/ └── adapter_model.safetensors ``` -The `adapter_config.json` file contains metadata about the LoRA adapter, including: +`adapter_config.json` contains: - `r`: LoRA rank - `lora_alpha`: LoRA alpha scaling factor -- `target_modules`: List of module names to apply LoRA to +- `target_modules`: list of module names the adapter applies to + +!!! note "Server-side Path Requirement" + The LoRA adapter path must be readable on the **server** machine. If your client and server are on different hosts, ensure the adapter is accessible via a shared mount or copied to the server. -## Quick Start -### Offline Inference +## Init-time LoRA -#### Pre-loaded LoRA +### How It Works -Load a LoRA adapter at initialization. This adapter is pre-loaded into the cache and can be activated by requests: +Passing `lora_path` to `Omni(...)` instructs the engine to register a single adapter at startup and activate it as the only adapter for every request. The adapter occupies one slot of the LoRA cache for the lifetime of the process. + +### Usage ```python from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="Tongyi-MAI/Z-Image-Turbo", + lora_path="/path/to/lora_adapter", + lora_scale=1.0, # optional, default 1.0 +) + +outputs = omni.generate( + "A piece of cheesecake", + OmniDiffusionSamplingParams(height=1024, width=1024, num_inference_steps=9), +) +images = outputs[0].request_output.images +``` + +The CLI wrapper `examples/offline_inference/text_to_image/text_to_image.py` exposes these two kwargs as `--lora-path` and `--lora-scale`: + +```bash +python examples/offline_inference/text_to_image/text_to_image.py \ + --model Tongyi-MAI/Z-Image-Turbo \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora_adapter \ + --lora-scale 1.0 \ + --output outputs/cheesecake.png +``` + +### Limitations + +- Exactly one adapter, chosen at init. 
The adapter cannot be swapped or disabled for individual requests — restart `Omni` to change it. +- Mutually exclusive with `--lora-paths` in the example CLI. Use per-request LoRA when you need different adapters on different requests. + + +## Per-request LoRA + +### How It Works + +Each request carries its own adapter set via `OmniDiffusionSamplingParams`: + +```python +sampling_params = OmniDiffusionSamplingParams( + ..., + lora_requests=[req_a, req_b], # list of LoRARequest + lora_scales=[1.0, 0.5], # same length as lora_requests +) +``` + +- `lora_requests=[]` (or omitted) → no LoRA applied to this request. +- `lora_requests=[req]` → single adapter at the given scale. +- `lora_requests=[req_a, req_b, ...]` → multi-LoRA: all listed adapters are activated simultaneously, each in its own cache slot, and their deltas are summed during the forward pass. + +The cache is sized by `max_loras` (defaults to 1). Set `Omni(..., max_loras=N)` when you plan to activate up to `N` adapters concurrently — requests exceeding this limit are rejected. The example CLI at `examples/offline_inference/text_to_image/text_to_image.py` auto-sizes this to `max(len(--lora-paths), 1)` when `--max-loras` is omitted. + +### Scale Semantics + +- `lora_scales[i]` multiplies adapter `i`'s contribution to the output delta. +- `lora_scales[i] == 0.0` is a registered-but-inactive slot: the adapter remains in the cache but contributes nothing this forward pass. This is distinct from omitting the adapter from `lora_requests`, which releases the slot. +- When `lora_requests` is set and `lora_scales` is omitted, every adapter defaults to scale `1.0`. 
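
The scale semantics above can be sketched with toy dense math. This is an illustration only, not vLLM-Omni's implementation: `matvec` and `lora_forward` are hypothetical helpers, the weights are made-up rank-1 matrices, and any `lora_alpha / r` factor is assumed to be folded into the adapter matrices already.

```python
# Illustrative sketch only: plain-Python math, not the engine's kernels.

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(base_w, adapters, scales, x):
    """Return base_w @ x plus the scaled sum of each adapter's B @ (A @ x)."""
    if len(adapters) != len(scales):
        raise ValueError("lora_scales must have the same length as lora_requests")
    out = matvec(base_w, x)
    for (lora_a, lora_b), scale in zip(adapters, scales):
        if scale == 0.0:
            continue  # registered-but-inactive slot: contributes nothing
        delta = matvec(lora_b, matvec(lora_a, x))
        out = [o + scale * d for o, d in zip(out, delta)]
    return out


base_w = [[1.0, 0.0], [0.0, 1.0]]            # frozen base weight (2x2 identity)
adapter_a = ([[1.0, 1.0]], [[1.0], [0.0]])   # rank 1: A is 1x2, B is 2x1
adapter_b = ([[1.0, -1.0]], [[0.0], [2.0]])
x = [3.0, 1.0]

both = lora_forward(base_w, [adapter_a, adapter_b], [1.0, 0.5], x)
only_a = lora_forward(base_w, [adapter_a, adapter_b], [1.0, 0.0], x)
print(both)    # [7.0, 3.0]
print(only_a)  # [7.0, 1.0]
```

With scale `0.0` the second adapter stays in the list but leaves the output unchanged, which mirrors the registered-but-inactive case.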
+ +### Usage + +**Single adapter (per-request):** + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id -lora_path="/path/to/lora_adapter" +omni = Omni(model="Tongyi-MAI/Z-Image-Turbo", max_loras=1) -omni = Omni( - model="stabilityai/stable-diffusion-3.5-medium", - lora_path=lora_path +req = LoRARequest( + lora_name="style_a", + lora_int_id=stable_lora_int_id("/path/to/style_a"), + lora_path="/path/to/style_a", ) -lora_request = LoRARequest( - lora_name="preloaded", - lora_int_id=1, - lora_path=lora_path +outputs = omni.generate( + "A piece of cheesecake", + OmniDiffusionSamplingParams( + height=1024, + width=1024, + num_inference_steps=9, + lora_requests=[req], + lora_scales=[1.0], + ), ) +``` + +**Multi-LoRA composition:** + +```python +omni = Omni(model="Tongyi-MAI/Z-Image-Turbo", max_loras=2) + +req_a = LoRARequest(lora_name="style_a", lora_int_id=stable_lora_int_id("/lora/a"), lora_path="/lora/a") +req_b = LoRARequest(lora_name="style_b", lora_int_id=stable_lora_int_id("/lora/b"), lora_path="/lora/b") outputs = omni.generate( - prompt="A piece of cheesecake", - lora_request=lora_request, - lora_scale=2.0, # optional arg, default 1.0 + "A piece of cheesecake", + OmniDiffusionSamplingParams( + height=1024, + width=1024, + num_inference_steps=9, + lora_requests=[req_a, req_b], + lora_scales=[1.0, 0.5], + ), ) ``` -!!! note "Server-side Path Requirement" - The LoRA adapter path (`local_path`) must be readable on the **server** machine. If your client and server are on different machines, ensure the LoRA adapter is accessible via a shared mount or copied to the server. +**Switching adapters between requests** — issue separate `omni.generate(...)` calls with different `OmniDiffusionSamplingParams`. 
`sampling_params_list` on `omni.generate` is stage-indexed (one entry per pipeline stage) and is shared across all prompts in a batch, so per-prompt adapter variance within a single batch call is not supported through that path. + +**CLI:** + +The example CLI exposes `--lora-paths` + `--lora-scales` for per-request composition, and `--axis` for Cartesian-product XYZ plots that can put any parameter on any of the three axes. Supported axis types are `prompt`, `lora_scale[i]` (i-th `--lora-paths` entry), `guidance_scale`, `num_inference_steps`, and `seed`. X is columns, Y is rows, Z writes one `grid_z{k}.png` per value. + +```bash +# Compose two adapters on one prompt +python examples/offline_inference/text_to_image/text_to_image.py \ + --model Tongyi-MAI/Z-Image-Turbo \ + --prompt "A piece of cheesecake" \ + --lora-paths /lora/a /lora/b \ + --lora-scales 1.0 0.5 \ + --max-loras 2 \ + --output-dir outputs/composed/ + +# 2×2 LoRA-scale grid across 2 prompts (Z): produces grid_z00.png + grid_z01.png +python examples/offline_inference/text_to_image/text_to_image.py \ + --model Tongyi-MAI/Z-Image-Turbo \ + --lora-paths /lora/a /lora/b \ + --max-loras 2 \ + --axis "x=lora_scale[0]:0|1" \ + --axis "y=lora_scale[1]:0|1" \ + --axis "z=prompt:a girl|a cat" \ + --output-dir outputs/axis_test/ +``` + +### Limitations + +- Up to `max_loras` adapters per request. Requests that exceed the limit fail fast before inference. +- All adapters in one request share the same forward pass; they must target compatible modules (scheme enforced by PEFT's `target_modules` field). Adapters targeting disjoint modules compose trivially; overlapping modules add linearly. +- `max_loras` sizes the cache at init and is not resizable at runtime. 
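
The slot limits above can be pictured with a small fixed-size LRU cache. This is a sketch under assumptions, not the engine's cache: `SlotCache`, `activate`, and the `loads` counter are hypothetical names, and the eviction policy shown (drop the least recently used adapter that the current request does not need) is one plausible reading of "LRU".

```python
from collections import OrderedDict


class SlotCache:
    """Hypothetical fixed-size adapter cache; the size is set once at init."""

    def __init__(self, max_loras: int = 1):
        self.max_loras = max_loras
        self._slots = OrderedDict()  # adapter id -> path, oldest first
        self.loads = 0               # number of (simulated) weight loads

    def activate(self, adapters: dict) -> None:
        """Make every adapter in `adapters` (id -> path) resident."""
        # Fail fast before touching any slot if the request cannot fit.
        if len(adapters) > self.max_loras:
            raise ValueError(
                f"request needs {len(adapters)} adapters but max_loras={self.max_loras}"
            )
        for lora_id, path in adapters.items():
            if lora_id in self._slots:
                self._slots.move_to_end(lora_id)  # cache hit: no reload
                continue
            while len(self._slots) >= self.max_loras:
                # Evict the least recently used adapter not needed by this request.
                victim = next(k for k in self._slots if k not in adapters)
                del self._slots[victim]
            self._slots[lora_id] = path
            self.loads += 1


cache = SlotCache(max_loras=2)
cache.activate({1: "/lora/a", 2: "/lora/b"})  # two loads
cache.activate({1: "/lora/a"})                # hit, no reload
cache.activate({3: "/lora/c"})                # evicts adapter 2
print(list(cache._slots), cache.loads)
```

Because the size is fixed at construction, a request that names more adapters than `max_loras` raises before any slot is modified.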
+ ## Wan2.2 LightX2V Offline Assembly diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index c71773972b..9e6c123112 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -74,6 +74,7 @@ python text_to_image.py \ | Argument | Type | Default | Description | | -------- | ---- | ------- | ----------- | | `--prompt` | str | `"a cup of coffee on the table"` | Text description for image generation | +| `--prompts` | str+ | — | Multiple prompts for batched generation. Overrides `--prompt`. Requires `--output-dir`. | | `--seed` | int | `142` | Integer seed for deterministic sampling | | `--negative-prompt` | str | `None` | Negative prompt for classifier-free conditional guidance | | `--cfg-scale` | float | `4.0` | True CFG scale (model-specific guidance strength) | @@ -82,7 +83,8 @@ python text_to_image.py \ | `--num-inference-steps` | int | `50` | Diffusion sampling steps (more steps = higher quality, slower) | | `--height` | int | `1024` | Output image height in pixels | | `--width` | int | `1024` | Output image width in pixels | -| `--output` | str | `"qwen_image_output.png"` | Path to save the generated image | +| `--output` | str | `"qwen_image_output.png"` | Single-image output file path (one prompt, one LoRA combo, one image) | +| `--output-dir` | str | — | Output directory for batch / multi-LoRA / XYZ runs. Files are named `cell_x{x}_y{y}_z{z}_n{n}.png`; `--axis` mode also writes `grid.png` (or `grid_z{k}.png` per Z value). 
| | `--vae-use-slicing` | flag | off | Enable VAE slicing for memory optimization | | `--vae-use-tiling` | flag | off | Enable VAE tiling for memory optimization | | `--cfg-parallel-size` | int | `1` | Set to `2` to enable CFG Parallel | @@ -90,8 +92,12 @@ python text_to_image.py \ | `--ring-degree` | int | `1` | Ring sequence parallel degree for hybrid Ulysses + Ring inference | | `--ulysses-mode` | str | `"strict"` | Ulysses SP mode: `"strict"` or `"advanced_uaa"` | | `--enable-cpu-offload` | flag | off | Enable CPU offloading for diffusion models | -| `--lora-path` | str | — | Path to PEFT LoRA adapter folder | -| `--lora-scale` | float | `1.0` | Scale factor for LoRA weights | +| `--lora-path` | str | — | Path to a PEFT LoRA adapter folder for init-time static load | +| `--lora-scale` | float | `1.0` | Scale factor for `--lora-path` | +| `--lora-paths` | str+ | — | One or more PEFT LoRA adapter folders for per-request composition. Mutually exclusive with `--lora-path`. | +| `--lora-scales` | float+ | `[1.0 ...]` | Per-adapter scales for `--lora-paths` (length must match) | +| `--max-loras` | int | auto | LoRA cache slot count. Defaults to `max(len(--lora-paths), 1)` | +| `--axis` | str (repeatable) | — | XYZ plot axis spec `NAME=TYPE:v1\|v2\|...` where NAME ∈ `{x,y,z}` and TYPE ∈ `{prompt, lora_scale[i], guidance_scale, num_inference_steps, seed}`. Cartesian product of axes defines cells; X=cols, Y=rows, Z produces one `grid_z{k}.png` per value. Repeat up to 3 times. | | `--use-system-prompt` | str | `None` | System prompt preset: `en_unified`, `en_vanilla`, `en_recaption`, `en_think_recaption`, `dynamic`, `None`, or custom text. Recommended: `en_unified`. Only for HunyuanImage-3.0.| | `--system-prompt` | str | `None` | Custom system prompt text. Only used when `--use-system-prompt` is set to `custom`. 
Only for HunyuanImage-3.0.| @@ -252,7 +258,9 @@ See more examples in the [cfg_parallel user guide](../../../docs/user_guide/para #### LoRA -This example supports PEFT-compatible LoRA (Low-Rank Adaptation) adapters for diffusion models. Pass `--lora-path` to use a LoRA adapter and optionally `--lora-scale` (default `1.0`); omit it to use the base model only. +This example supports PEFT-compatible LoRA (Low-Rank Adaptation) adapters in two modes — see the [LoRA feature guide](../../../docs/user_guide/diffusion/lora.md) for a full description. + +**Init-time LoRA** — one adapter is pre-loaded when `Omni` starts and applied to every generation: ```bash python text_to_image.py \ @@ -263,6 +271,36 @@ python text_to_image.py \ --output output.png ``` +**Per-request LoRA (incl. multi-LoRA composition)** — one or more adapters are attached to each request via `sampling_params.lora_requests`. Size the adapter cache with `--max-loras`: + +```bash +python text_to_image.py \ + --model Tongyi-MAI/Z-Image-Turbo \ + --prompt "A piece of cheesecake" \ + --lora-paths /lora/style_a /lora/style_b \ + --lora-scales 1.0 0.5 \ + --max-loras 2 \ + --output-dir outputs/composed/ +``` + +**XYZ plot** — put any parameter on any axis and take the Cartesian product. Each `--axis` has the form `NAME=TYPE:v1|v2|...` where `NAME` is `x` / `y` / `z` and `TYPE` is one of `prompt`, `lora_scale[i]` (targets the i-th `--lora-paths` entry), `guidance_scale`, `num_inference_steps`, or `seed`. X/Y compose a 2D grid; Z writes one `grid_z{k}.png` per value. 
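
As a rough illustration of how the Cartesian product maps to output files, the sketch below expands a set of axis specs into cell and grid file names. `expand_axes` is a hypothetical helper written for this note, not part of `text_to_image.py`; the file-name patterns follow the script's `cell_x{x}_y{y}_z{z}_n{n}.png` and `grid_z{k}.png` conventions, and it assumes one image per cell (`n00`).

```python
from itertools import product


def expand_axes(specs):
    """Return (cell file names, grid file names) for a list of --axis specs."""
    values = {"x": [None], "y": [None], "z": [None]}  # unused axes collapse to index 0
    named = set()
    for spec in specs:
        name, _, rest = spec.partition("=")
        _axis_type, _, raw = rest.partition(":")
        name = name.strip().lower()
        values[name] = [v.strip() for v in raw.split("|") if v.strip()]
        named.add(name)
    cells = [
        f"cell_x{x:02d}_y{y:02d}_z{z:02d}_n00.png"
        for z, y, x in product(
            range(len(values["z"])), range(len(values["y"])), range(len(values["x"]))
        )
    ]
    if "z" in named:
        grids = [f"grid_z{z:02d}.png" for z in range(len(values["z"]))]
    else:
        grids = ["grid.png"]
    return cells, grids


cells, grids = expand_axes(
    ["x=lora_scale[0]:0|1", "y=lora_scale[1]:0|1", "z=prompt:a girl|a cat"]
)
print(len(cells), grids)  # 8 ['grid_z00.png', 'grid_z01.png']
```

Two values on each of three axes give 2 × 2 × 2 = 8 cells and one 2×2 grid per Z value.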
+ +```bash +# 2 × 2 scale sweep of two LoRAs, across 2 prompts (Z) +python text_to_image.py \ + --model Tongyi-MAI/Z-Image-Turbo \ + --lora-paths /lora/style_a /lora/style_b \ + --max-loras 2 \ + --axis "x=lora_scale[0]:0|1" \ + --axis "y=lora_scale[1]:0|1" \ + --axis "z=prompt:a girl|a cat" \ + --output-dir outputs/axis_test/ +``` + +Grid cells are labeled with `{adapter}\n{scale}` for LoRA-scale axes and the prompt text for a prompt axis; the Z banner shows the current slice. + +`--lora-path` and `--lora-paths` are mutually exclusive. `--output-dir` is required whenever the script produces more than one image (multiple prompts, any `--axis`, or `--num-images-per-prompt > 1`). + LoRA adapters must be in PEFT format. A typical adapter directory structure: ``` diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index bc18c68591..74e151f02b 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -3,7 +3,10 @@ import argparse import json +import re +import textwrap import time +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -59,6 +62,13 @@ def parse_args() -> argparse.Namespace: help="Path to a YAML file containing stage configurations for Omni.", ) parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") + parser.add_argument( + "--prompts", + nargs="+", + default=None, + help="Multiple prompts to generate in one run. Overrides --prompt when set. " + "Each prompt is dispatched in its own omni.generate() call.", + ) parser.add_argument( "--negative-prompt", default=None, @@ -83,7 +93,16 @@ def parse_args() -> argparse.Namespace: "--output", type=str, default="qwen_image_output.png", - help="Path to save the generated image (PNG).", + help="Path to save the generated image (PNG). 
Used only in single-output mode " + "(one prompt, one LoRA combo, one image). Ignored when --output-dir is set.", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory for batch/XYZ output. Required for multiple prompts, multiple images per prompt, " + "or --axis. Files are saved as cell_x{x}_y{y}_z{z}_n{n}.png; with --axis a " + "grid.png (or grid_z{k}.png per Z value) is also written.", ) parser.add_argument( "--num-images-per-prompt", @@ -220,13 +239,48 @@ def parse_args() -> argparse.Namespace: "--lora-path", type=str, default=None, - help="Path to LoRA adapter folder (PEFT format). Loaded at initialization and used for generation.", + help="Path to LoRA adapter folder (PEFT format). Init-time static load: the adapter is " + "pre-loaded into the engine cache and applied to every request. Mutually exclusive with --lora-paths.", ) parser.add_argument( "--lora-scale", type=float, default=1.0, - help="Scale factor for LoRA weights (default: 1.0).", + help="Scale factor for --lora-path (default: 1.0).", + ) + parser.add_argument( + "--lora-paths", + nargs="+", + default=None, + help="Multiple LoRA adapter folders (PEFT format) for per-request composition. " + "Each request applies all listed adapters with the matching --lora-scales. " + "Mutually exclusive with --lora-path.", + ) + parser.add_argument( + "--lora-scales", + nargs="+", + type=float, + default=None, + help="Per-adapter scales for --lora-paths. Length must match --lora-paths; " + "defaults to 1.0 per adapter when omitted.", + ) + parser.add_argument( + "--max-loras", + type=int, + default=None, + help="Maximum number of LoRA slots active simultaneously. Defaults to max(len(--lora-paths), 1).", + ) + parser.add_argument( + "--axis", + action="append", + default=None, + metavar="SPEC", + help="XYZ axis. Repeat up to 3 times. Spec form: NAME=TYPE:v1|v2|v3 where " + "NAME ∈ {x,y,z} and TYPE ∈ {prompt, lora_scale[i], guidance_scale, " + "num_inference_steps, seed}. 
The Cartesian product of X×Y×Z defines cells: " + "X is columns, Y is rows, Z produces one grid per value (grid_z{k}.png). " + 'Example: --axis "x=lora_scale[0]:0|1" --axis "y=lora_scale[1]:0|1" ' + '--axis "z=prompt:a girl|a cat" yields a 2×2 grid per prompt.', ) parser.add_argument( "--vae-patch-parallel-size", @@ -307,11 +361,231 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() +def _resolve_prompts(args: argparse.Namespace) -> list[str]: + """Return the list of prompts to run. Prefers --prompts; falls back to --prompt.""" + if args.prompts: + return list(args.prompts) + return [args.prompt] + + +def _build_lora_request(path: str) -> LoRARequest: + return LoRARequest( + lora_name=Path(path).stem, + lora_int_id=stable_lora_int_id(path), + lora_path=path, + ) + + +def _resolve_lora( + args: argparse.Namespace, +) -> tuple[list[LoRARequest], list[float], bool]: + """Return (lora_requests, lora_scales, is_per_request) for the default cell. + + ``is_per_request`` is True when --lora-paths is given (per-request LoRA), + False when --lora-path (init-time) or no LoRA is used. + """ + if args.lora_path and args.lora_paths: + raise ValueError("--lora-path and --lora-paths are mutually exclusive.") + + if not args.lora_paths: + return [], [], False + + lora_paths = list(args.lora_paths) + lora_scales = list(args.lora_scales) if args.lora_scales is not None else [1.0] * len(lora_paths) + if len(lora_paths) != len(lora_scales): + raise ValueError( + f"--lora-paths ({len(lora_paths)}) and --lora-scales ({len(lora_scales)}) must have the same length." 
+ ) + requests = [_build_lora_request(p) for p in lora_paths] + return requests, lora_scales, True + + +_LORA_SCALE_TYPE_RE = re.compile(r"^lora_scale\[(\d+)\]$") +_AXIS_TYPES = {"prompt", "guidance_scale", "num_inference_steps", "seed"} + + +@dataclass +class _Axis: + name: str # 'x' | 'y' | 'z' + type: str + values: list[str] # raw strings; converted per type when applied + + +def _parse_axes(specs: list[str] | None) -> dict[str, _Axis]: + """Parse repeated --axis specs into a dict keyed by axis name.""" + if not specs: + return {} + axes: dict[str, _Axis] = {} + for spec in specs: + name_part, sep, rest = spec.partition("=") + if not sep: + raise ValueError(f"--axis spec missing '=': {spec!r}") + type_part, sep, values_part = rest.partition(":") + if not sep: + raise ValueError(f"--axis spec missing ':' between type and values: {spec!r}") + name = name_part.strip().lower() + if name not in ("x", "y", "z"): + raise ValueError(f"--axis name must be x, y, or z; got {name!r}") + if name in axes: + raise ValueError(f"--axis {name} specified twice") + atype = type_part.strip() + if atype not in _AXIS_TYPES and not _LORA_SCALE_TYPE_RE.match(atype): + raise ValueError( + f"--axis type {atype!r} unknown. Supported: prompt, lora_scale[i], " + f"guidance_scale, num_inference_steps, seed." + ) + values = [v.strip() for v in values_part.split("|") if v.strip()] + if not values: + raise ValueError(f"--axis {name} has no values: {spec!r}") + axes[name] = _Axis(name=name, type=atype, values=values) + return axes + + +def _axis_label(axis: _Axis, value: str, lora_names: list[str]) -> str: + """Render a short cell-header label. Embeds a newline between name and value + so wide labels (e.g. ``lora_chardesign=1.00``) wrap cleanly in the grid + margin strips; the grid composer honors explicit newlines verbatim. + """ + if axis.type == "prompt": + s = value if len(value) <= 40 else value[:37] + "..." 
+ return s + m = _LORA_SCALE_TYPE_RE.match(axis.type) + if m: + idx = int(m.group(1)) + name = lora_names[idx] if idx < len(lora_names) else f"lora[{idx}]" + return f"{name}\n{float(value):.2f}" + return f"{axis.type}\n{value}" + + +def _apply_axis( + axis: _Axis, + raw_value: str, + cell: dict, + lora_count: int, +) -> None: + """Mutate cell in place by applying a single axis value.""" + t = axis.type + if t == "prompt": + cell["prompt"] = raw_value + return + m = _LORA_SCALE_TYPE_RE.match(t) + if m: + idx = int(m.group(1)) + if idx >= lora_count: + raise ValueError(f"axis lora_scale[{idx}] but only {lora_count} LoRA(s) provided via --lora-paths") + cell["lora_scales"] = list(cell["lora_scales"]) + cell["lora_scales"][idx] = float(raw_value) + return + if t == "guidance_scale": + cell["guidance_scale"] = float(raw_value) + return + if t == "num_inference_steps": + cell["num_inference_steps"] = int(raw_value) + return + if t == "seed": + cell["seed"] = int(raw_value) + return + raise ValueError(f"axis type {t!r} not implemented") + + +def _compose_grid( + results: dict[tuple[int, int], list[Any]], + num_rows: int, + num_cols: int, + row_labels: list[str] | None = None, + col_labels: list[str] | None = None, + title: str | None = None, +) -> Any: + """Stitch the first image of each (row, col) cell into a single grid PIL image. + + Optional ``row_labels`` / ``col_labels`` reserve left/top strips with per-row + and per-column text. Optional ``title`` reserves a narrow banner on top. 
+ """ + from PIL import Image, ImageDraw + + sample = next(iter(results.values()))[0] + cell_w, cell_h = sample.width, sample.height + + col_strip = max(64, cell_h // 10) if col_labels else 0 + row_strip = max(220, cell_w // 4) if row_labels else 0 + title_strip = max(48, cell_h // 14) if title else 0 + + top = title_strip + col_strip + left = row_strip + + grid = Image.new("RGB", (left + cell_w * num_cols, top + cell_h * num_rows), color="white") + draw = ImageDraw.Draw(grid) + + font = _load_label_font(max(18, (col_strip or cell_h // 10) // 3)) + font_row = _load_label_font(max(16, (col_strip or cell_h // 10) // 4)) + font_title = _load_label_font(max(22, title_strip // 2)) if title else font + + if title: + draw.text( + (grid.size[0] // 2, title_strip // 2), + title, + fill="black", + font=font_title, + anchor="mm", + ) + if col_labels: + for c_idx, lbl in enumerate(col_labels): + x = left + c_idx * cell_w + cell_w // 2 + y = title_strip + col_strip // 2 + draw.text((x, y), lbl, fill="black", font=font, anchor="mm", align="center") + if row_labels: + for r_idx, lbl in enumerate(row_labels): + y = top + r_idx * cell_h + cell_h // 2 + # Honor explicit newlines from axis labels; otherwise soft-wrap long text. 
+ rendered = lbl if "\n" in lbl else ("\n".join(textwrap.wrap(lbl, width=18)) or lbl) + draw.text((row_strip // 2, y), rendered, fill="black", font=font_row, anchor="mm", align="center") + + for (r, c), imgs in results.items(): + grid.paste(imgs[0], (left + c * cell_w, top + r * cell_h)) + return grid + + +def _load_label_font(size: int): + """Return a readable TrueType font if available, otherwise PIL's default bitmap font.""" + from PIL import ImageFont + + for candidate in ( + "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", + "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", + "DejaVuSans-Bold.ttf", + ): + try: + return ImageFont.truetype(candidate, size=size) + except OSError: + continue + return ImageFont.load_default() + + def main(): args = parse_args() - generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) use_nextstep = is_nextstep_model(args.model) + prompts = _resolve_prompts(args) + lora_requests, lora_scales, lora_is_per_request = _resolve_lora(args) + axes = _parse_axes(args.axis) + + if axes and len(prompts) > 1: + raise ValueError( + "--axis cannot be combined with multi-prompt input; put prompts on the prompt axis instead, " + 'e.g. --axis "z=prompt:a|b".' + ) + requires_output_dir = len(prompts) > 1 or args.num_images_per_prompt > 1 or bool(axes) + if requires_output_dir and not args.output_dir: + raise ValueError( + "--output-dir is required when running multiple prompts, multiple images per prompt, " + "or --axis. Single --output is only valid for one image." + ) + if args.max_loras is not None and args.lora_paths and args.max_loras < len(args.lora_paths): + raise ValueError( + f"--max-loras ({args.max_loras}) is smaller than len(--lora-paths) ({len(args.lora_paths)}). " + "Composition needs one slot per adapter — raise --max-loras or remove it to auto-size." 
+ ) + cache_config = None cache_backend = args.cache_backend @@ -359,7 +633,14 @@ def main(): lora_args: dict[str, Any] = {} if args.lora_path: lora_args["lora_path"] = args.lora_path - print(f"Using LoRA from: {args.lora_path}") + print(f"Using init-time LoRA from: {args.lora_path}") + + # max_loras sizes the adapter cache at init. Per-request combos may load + # several adapters simultaneously, so default to max(len(lora_paths), 1). + if args.max_loras is not None: + lora_args["max_loras"] = args.max_loras + elif args.lora_paths: + lora_args["max_loras"] = max(len(args.lora_paths), 1) # Build quantization kwargs: use quantization_config dict when # ignored_layers is specified so the list flows through OmniDiffusionConfig @@ -430,22 +711,18 @@ def main(): print(f" CPU offload: {args.enable_cpu_offload}; CPU Layerwise Offload: {args.enable_layerwise_offload}") print(f" Image size: {args.width}x{args.height}") if args.lora_path: - print(f" LoRA: scale={args.lora_scale}") + print(f" Init-time LoRA: scale={args.lora_scale}") + if lora_is_per_request: + print(f" Per-request LoRA ({len(lora_requests)}):") + for idx, (req, scale) in enumerate(zip(lora_requests, lora_scales)): + print(f" [{idx}] {req.lora_name} scale={scale}") + print(f" Prompts: {len(prompts)}") + if axes: + print(f" Axes: {', '.join(f'{a.name}={a.type}:{len(a.values)} values' for a in axes.values())}") if args.stage_configs_path: print(f" stage-configs-path: {args.stage_configs_path}") print(f"{'=' * 60}\n") - # Build LoRA request when --lora-path is set - lora_request = None - if args.lora_path: - lora_request_id = stable_lora_int_id(args.lora_path) - lora_request = LoRARequest( - lora_name=Path(args.lora_path).stem, - lora_int_id=lora_request_id, - lora_path=args.lora_path, - ) - - generation_start = time.perf_counter() extra_args = { "timesteps_shift": args.timesteps_shift, "cfg_schedule": args.cfg_schedule, @@ -453,27 +730,73 @@ def main(): "use_system_prompt": args.use_system_prompt, 
"system_prompt": args.system_prompt, } - if lora_request: - extra_args["lora_request"] = lora_request - extra_args["lora_scale"] = args.lora_scale - - outputs = omni.generate( - { - "prompt": args.prompt, - "negative_prompt": args.negative_prompt, - }, - OmniDiffusionSamplingParams( + + def _run_cell(prompt: str, cell: dict) -> list[Any]: + gen = torch.Generator(device=current_omni_platform.device_type).manual_seed(cell["seed"]) + sp = OmniDiffusionSamplingParams( height=args.height, width=args.width, - generator=generator, + generator=gen, true_cfg_scale=args.cfg_scale, - guidance_scale=args.guidance_scale, + guidance_scale=cell["guidance_scale"], guidance_scale_2=args.guidance_scale_2, - num_inference_steps=args.num_inference_steps, + num_inference_steps=cell["num_inference_steps"], num_outputs_per_prompt=args.num_images_per_prompt, + lora_requests=lora_requests if lora_is_per_request else [], + lora_scales=cell["lora_scales"] if lora_is_per_request else [], extra_args=extra_args, - ), - ) + ) + outs = omni.generate([{"prompt": prompt, "negative_prompt": args.negative_prompt}], sp) + if not outs or not getattr(outs[0], "request_output", None): + raise ValueError("Generate returned no request_output") + imgs = outs[0].request_output.images + if not imgs: + raise ValueError("Empty image list from generate") + return imgs + + defaults = { + "prompt": prompts[0], + "lora_scales": list(lora_scales), + "guidance_scale": args.guidance_scale, + "num_inference_steps": args.num_inference_steps, + "seed": args.seed, + } + + # (z_idx, y_idx, x_idx) -> images for one cell. Unused axes collapse to idx 0. 
+ cell_images: dict[tuple[int, int, int], list[Any]] = {} + + generation_start = time.perf_counter() + + if axes: + x_axis, y_axis, z_axis = axes.get("x"), axes.get("y"), axes.get("z") + z_values = z_axis.values if z_axis else [None] + y_values = y_axis.values if y_axis else [None] + x_values = x_axis.values if x_axis else [None] + + total = len(z_values) * len(y_values) * len(x_values) + counter = 0 + for z_idx, z_val in enumerate(z_values): + for y_idx, y_val in enumerate(y_values): + for x_idx, x_val in enumerate(x_values): + cell = dict(defaults) + for ax, raw in ((x_axis, x_val), (y_axis, y_val), (z_axis, z_val)): + if ax is not None: + _apply_axis(ax, raw, cell, len(lora_requests)) + counter += 1 + label = " ".join( + f"{ax.name}={raw}" + for ax, raw in ((x_axis, x_val), (y_axis, y_val), (z_axis, z_val)) + if ax is not None + ) + print(f"[cell {counter}/{total}] {label}") + cell_images[(z_idx, y_idx, x_idx)] = _run_cell(cell["prompt"], cell) + else: + # No axes: generate one image per prompt; cell key (0, p_idx, 0). 
+ cell = dict(defaults) + for p_idx, prompt in enumerate(prompts): + cell["prompt"] = prompt + print(f"[cell {p_idx + 1}/{len(prompts)}] prompt={prompt!r}") + cell_images[(0, p_idx, 0)] = _run_cell(prompt, cell) generation_end = time.perf_counter() generation_time = generation_end - generation_start @@ -498,35 +821,48 @@ def main(): else: print("[Profiler] No valid profiling data returned.") - # omni.generate() returns list[OmniRequestOutput] - if not outputs or len(outputs) == 0: - raise ValueError("No output generated from omni.generate()") - logger.info(f"Outputs: {outputs}") - - first_output = outputs[0] - if not hasattr(first_output, "request_output") or not first_output.request_output: - raise ValueError("No request_output found in OmniRequestOutput") - - req_out = first_output.request_output - if not hasattr(req_out, "images"): - raise ValueError("Invalid request_output structure or missing 'images'.") - - images = req_out.images - if not images: - raise ValueError("No images found in request_output") - - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - suffix = output_path.suffix or ".png" - stem = output_path.stem or "qwen_image_output" - if len(images) <= 1: - images[0].save(output_path) - print(f"Saved generated image to {output_path}") + logger.info("Produced %d cells", len(cell_images)) + + if args.output_dir: + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + for (z_idx, y_idx, x_idx), imgs in cell_images.items(): + for n_idx, img in enumerate(imgs): + save_path = out_dir / f"cell_x{x_idx:02d}_y{y_idx:02d}_z{z_idx:02d}_n{n_idx:02d}.png" + img.save(save_path) + print(f"Saved {save_path}") + + if axes: + x_axis, y_axis, z_axis = axes.get("x"), axes.get("y"), axes.get("z") + lora_names = [req.lora_name for req in lora_requests] + col_labels = [_axis_label(x_axis, v, lora_names) for v in x_axis.values] if x_axis else None + row_labels = [_axis_label(y_axis, v, lora_names) for v in 
y_axis.values] if y_axis else None + num_cols = len(x_axis.values) if x_axis else 1 + num_rows = len(y_axis.values) if y_axis else 1 + + z_values = z_axis.values if z_axis else [None] + for z_idx, z_val in enumerate(z_values): + slice_cells = {(y, x): imgs for (z, y, x), imgs in cell_images.items() if z == z_idx} + title = f"Z: {_axis_label(z_axis, z_val, lora_names)}" if z_axis else None + grid = _compose_grid( + slice_cells, + num_rows=num_rows, + num_cols=num_cols, + row_labels=row_labels, + col_labels=col_labels, + title=title, + ) + fname = f"grid_z{z_idx:02d}.png" if z_axis else "grid.png" + grid_path = out_dir / fname + grid.save(grid_path) + print(f"Saved grid to {grid_path}") else: - for idx, img in enumerate(images): - save_path = output_path.parent / f"{stem}_{idx}{suffix}" - img.save(save_path) - print(f"Saved generated image to {save_path}") + # Single-output mode: exactly one cell, one image. + only_images = next(iter(cell_images.values())) + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + only_images[0].save(output_path) + print(f"Saved generated image to {output_path}") if __name__ == "__main__": diff --git a/tests/diffusion/lora/helpers.py b/tests/diffusion/lora/helpers.py index 8b9b1ef4d2..31d282ddb6 100644 --- a/tests/diffusion/lora/helpers.py +++ b/tests/diffusion/lora/helpers.py @@ -16,25 +16,24 @@ def __init__(self): class DummyBaseLayerWithLoRA(torch.nn.Module): - """Fake LoRA wrapper that records set/reset/create calls.""" + """Fake LoRA wrapper that records set/reset/create calls per slot.""" def __init__(self, base_layer: torch.nn.Module): super().__init__() self.base_layer = base_layer self.set_calls: list[ - tuple[list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor] + tuple[int, list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor] ] = [] - self.reset_calls: int = 0 + self.reset_calls: list[int] = [] self.create_calls: int = 0 + 
self._n_active_adapters: int = 0 def set_lora(self, index: int, lora_a, lora_b): - assert index == 0 - self.set_calls.append((lora_a, lora_b)) + self.set_calls.append((index, lora_a, lora_b)) def reset_lora(self, index: int): - assert index == 0 - self.reset_calls += 1 + self.reset_calls.append(index) def create_lora_weights(self, max_loras, lora_config, model_config): self.create_calls += 1 diff --git a/tests/diffusion/lora/test_lora_manager.py b/tests/diffusion/lora/test_lora_manager.py index 785f5d8421..4338fc67f4 100644 --- a/tests/diffusion/lora/test_lora_manager.py +++ b/tests/diffusion/lora/test_lora_manager.py @@ -23,18 +23,18 @@ class _DummyLoRALayer: def __init__(self, n_slices: int, output_slices: tuple[int, ...]): self.n_slices = n_slices self.output_slices = output_slices + # Keyed by slot index self.set_calls: list[ - tuple[list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor] + tuple[int, list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor] ] = [] - self.reset_calls: int = 0 + self.reset_calls: list[int] = [] + self._n_active_adapters: int = 0 def set_lora(self, index: int, lora_a, lora_b): - assert index == 0 - self.set_calls.append((lora_a, lora_b)) + self.set_calls.append((index, lora_a, lora_b)) def reset_lora(self, index: int): - assert index == 0 - self.reset_calls += 1 + self.reset_calls.append(index) # Aliases for backward compatibility within this file @@ -52,24 +52,24 @@ def __init__(self): class _DummyLM(torch.nn.Module): """LoRA enabled wrapper for _DummyPipeline.""" - def __init__(self, rank: int): + def __init__(self, rank: int, lora_a_val: float = 1.0, lora_b_val: float = 1.0): super().__init__() self.transformer = torch.nn.Module() self.transformer.foo = _DummyBaseLayerWithLoRA(_FakeLinearBase()) self.rank = rank - self.loras = self.get_lora_modules() + self.loras = self.get_lora_modules(lora_a_val, lora_b_val) - def get_lora_modules(self): - return {"transformer.foo": 
self._get_initial_lora(self.rank)} + def get_lora_modules(self, lora_a_val: float = 1.0, lora_b_val: float = 1.0): + return {"transformer.foo": self._get_initial_lora(self.rank, lora_a_val, lora_b_val)} def get_lora(self, k: str) -> LoRALayerWeights: """Get the unscaled LoRA weights for transformer.foo""" return self.loras[k] - def _get_initial_lora(self, rank: int) -> LoRALayerWeights: + def _get_initial_lora(self, rank: int, lora_a_val: float = 1.0, lora_b_val: float = 1.0) -> LoRALayerWeights: """Initializes a dummy LoRA for the current rank.""" - A = torch.ones((rank, 4)) - B = torch.ones((4, rank)) + A = torch.ones((rank, 4)) * lora_a_val + B = torch.ones((4, rank)) * lora_b_val return LoRALayerWeights( module_name="foo", rank=rank, @@ -218,11 +218,12 @@ def test_lora_manager_activates_fused_lora_on_packed_layer(): )() } - manager._activate_adapter(7, 0.5) + manager._activate_adapters([7], [0.5]) - assert packed_layer.reset_calls == 0 - assert len(packed_layer.set_calls) == 1 - lora_a_list, lora_b_list = packed_layer.set_calls[0] + # Filter set_calls for slot 0 + slot0_sets = [(a, b) for idx, a, b in packed_layer.set_calls if idx == 0] + assert len(slot0_sets) == 1 + lora_a_list, lora_b_list = slot0_sets[0] assert isinstance(lora_a_list, list) assert isinstance(lora_b_list, list) assert len(lora_a_list) == 3 @@ -266,11 +267,11 @@ def test_lora_manager_activates_packed_lora_from_sublayers(): 1: type("LM", (), {"id": 1, "loras": loras, "get_lora": lambda self, k: self.loras.get(k)})() } - manager._activate_adapter(1, scale=2.0) + manager._activate_adapters([1], [2.0]) - assert packed_layer.reset_calls == 0 - assert len(packed_layer.set_calls) == 1 - lora_a_list, lora_b_list = packed_layer.set_calls[0] + slot0_sets = [(a, b) for idx, a, b in packed_layer.set_calls if idx == 0] + assert len(slot0_sets) == 1 + lora_a_list, lora_b_list = slot0_sets[0] assert isinstance(lora_a_list, list) assert isinstance(lora_b_list, list) assert len(lora_a_list) == 3 @@ -304,19 
+305,19 @@ def _fake_load(_req: LoRARequest): monkeypatch.setattr(manager, "_load_adapter", _fake_load) monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) - monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id, scale: None) + monkeypatch.setattr(manager, "_activate_adapters", lambda _ids, _scales: None) req1 = _dummy_lora_request(1) req2 = _dummy_lora_request(2) req3 = _dummy_lora_request(3) - manager.set_active_adapter(req1, lora_scale=1.0) - manager.set_active_adapter(req2, lora_scale=1.0) + manager.set_active_adapters([req1], [1.0]) + manager.set_active_adapters([req2], [1.0]) # Touch adapter 1 so adapter 2 becomes LRU. - manager.set_active_adapter(req1, lora_scale=1.0) + manager.set_active_adapters([req1], [1.0]) - manager.set_active_adapter(req3, lora_scale=1.0) + manager.set_active_adapters([req3], [1.0]) assert set(manager.list_adapters()) == {1, 3} @@ -336,13 +337,13 @@ def _fake_load(_req: LoRARequest): monkeypatch.setattr(manager, "_load_adapter", _fake_load) monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) - monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id, scale: None) + monkeypatch.setattr(manager, "_activate_adapters", lambda _ids, _scales: None) - manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0) + manager.set_active_adapters([_dummy_lora_request(1)], [1.0]) assert manager.pin_adapter(1) - manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0) - manager.set_active_adapter(_dummy_lora_request(3), lora_scale=1.0) + manager.set_active_adapters([_dummy_lora_request(2)], [1.0]) + manager.set_active_adapters([_dummy_lora_request(3)], [1.0]) assert set(manager.list_adapters()) == {1, 3} @@ -362,10 +363,10 @@ def _fake_load(_req: LoRARequest): monkeypatch.setattr(manager, "_load_adapter", _fake_load) monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) - monkeypatch.setattr(manager, "_activate_adapter", lambda 
_adapter_id, scale: None) + monkeypatch.setattr(manager, "_activate_adapters", lambda _ids, _scales: None) - manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0) - manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0) + manager.set_active_adapters([_dummy_lora_request(1)], [1.0]) + manager.set_active_adapters([_dummy_lora_request(2)], [1.0]) assert manager.pin_adapter(1) assert manager.pin_adapter(2) @@ -408,17 +409,19 @@ def _fake_load(_req: LoRARequest): manager._lora_modules = {"transformer.foo": lora_model.transformer.foo} # After the first scale, all B values should go from 1 -> scale_1 - manager.set_active_adapter(req1, lora_scale=scale_1) - assert len(lora_model.transformer.foo.set_calls) == 1 - lora_a, lora_b = lora_model.transformer.foo.set_calls[0] + manager.set_active_adapters([req1], [scale_1]) + slot0_sets = [(a, b) for idx, a, b in lora_model.transformer.foo.set_calls if idx == 0] + assert len(slot0_sets) == 1 + lora_a, lora_b = slot0_sets[0] assert torch.all(lora_a == 1) assert torch.all(lora_b == scale_1) # After the second scale, all B values should go from 1 -> scale_2 - manager.set_active_adapter(req1, lora_scale=scale_2) - assert len(lora_model.transformer.foo.set_calls) == 2 + manager.set_active_adapters([req1], [scale_2]) + slot0_sets = [(a, b) for idx, a, b in lora_model.transformer.foo.set_calls if idx == 0] + assert len(slot0_sets) == 2 - lora_a, lora_b = lora_model.transformer.foo.set_calls[1] + lora_a, lora_b = slot0_sets[1] assert torch.all(lora_a == 1) assert torch.all(lora_b == scale_2) @@ -454,10 +457,11 @@ def _fake_load(_req: LoRARequest): manager._lora_modules = {"transformer.foo": lora_model.transformer.foo} # Activate adapter with initial scale - manager.set_active_adapter(req1, lora_scale=initial_scale) + manager.set_active_adapters([req1], [initial_scale]) assert lora_model.transformer.foo.create_calls == 0 - assert len(lora_model.transformer.foo.set_calls) == 1 - lora_a, lora_b = 
lora_model.transformer.foo.set_calls[0] + slot0_sets = [(a, b) for idx, a, b in lora_model.transformer.foo.set_calls if idx == 0] + assert len(slot0_sets) == 1 + lora_a, lora_b = slot0_sets[0] assert torch.all(lora_a == 1) assert torch.all(lora_b == initial_scale) @@ -465,27 +469,18 @@ def _fake_load(_req: LoRARequest): manager._ensure_max_lora_rank(8) # Ensure we actually took the rank expansion path, which recreates - # and sets the weight buffets, but that the scale didn't change + # and sets the weight buffers, but that the scale didn't change assert lora_model.transformer.foo.create_calls == 1 - assert len(lora_model.transformer.foo.set_calls) == 2 - lora_a, lora_b = lora_model.transformer.foo.set_calls[1] + slot0_sets = [(a, b) for idx, a, b in lora_model.transformer.foo.set_calls if idx == 0] + assert len(slot0_sets) == 2 + lora_a, lora_b = slot0_sets[1] assert torch.all(lora_a == 1) assert torch.all(lora_b == initial_scale) -def test_scale_keys_are_rounded(): - """Ensure that added adapter scales are rounded to avoid lookup - issues due to precision differences, e.g., computed scales. 
- """ - manager = DiffusionLoRAManager( - pipeline=_DummyPipeline(), - device=torch.device("cpu"), - dtype=torch.bfloat16, - ) - adapter_id = 1 - # Currently we round keys to 3 decimal places - manager._update_adapter_scale(adapter_id, 0.0031) - assert manager._adapter_scales[adapter_id] == 0.003 +def test_scale_rounding(): + """Ensure that scales are rounded for comparison.""" + assert DiffusionLoRAManager._get_rounded_scale(0.0031) == 0.003 def test_lora_manager_uses_valid_max_rank(monkeypatch): @@ -535,6 +530,218 @@ def _fake_load(_req: LoRARequest): manager.add_adapter(req1) +# ============================================================ +# Multi-LoRA composition tests +# ============================================================ + + +def test_multi_adapter_activation(): + """Verify multiple adapters are set into separate slots.""" + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_loras=3, + max_cached_adapters=3, + ) + + layer = _DummyLoRALayer(n_slices=1, output_slices=(4,)) + manager._lora_modules = {"transformer.foo": layer} + + rank = 2 + adapters = {} + for aid in [1, 2, 3]: + lora = LoRALayerWeights( + module_name="foo", + rank=rank, + lora_alpha=rank, + lora_a=torch.ones((rank, 4)) * aid, + lora_b=torch.ones((4, rank)) * aid, + ) + adapters[aid] = type( + "LM", (), {"id": aid, "loras": {"transformer.foo": lora}, "get_lora": lambda self, k: self.loras.get(k)} + )() + + manager._registered_adapters = adapters + + manager._activate_adapters([1, 2, 3], [0.5, 0.75, 1.0]) + + # Should have 3 set_lora calls (one per adapter) for the single layer + set_by_slot = {idx: (a, b) for idx, a, b in layer.set_calls} + assert 0 in set_by_slot + assert 1 in set_by_slot + assert 2 in set_by_slot + + # Verify weights are correct per slot + a0, b0 = set_by_slot[0] + assert torch.allclose(a0, torch.ones((rank, 4)) * 1) + assert torch.allclose(b0, torch.ones((4, rank)) * 1 * 0.5) + + a1, b1 = 
set_by_slot[1] + assert torch.allclose(a1, torch.ones((rank, 4)) * 2) + assert torch.allclose(b1, torch.ones((4, rank)) * 2 * 0.75) + + a2, b2 = set_by_slot[2] + assert torch.allclose(a2, torch.ones((rank, 4)) * 3) + assert torch.allclose(b2, torch.ones((4, rank)) * 3 * 1.0) + + assert layer._n_active_adapters == 3 + + +def test_multi_adapter_unused_slots_are_reset(): + """When going from 3 adapters to 1, unused slots should be reset.""" + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_loras=3, + max_cached_adapters=3, + ) + + layer = _DummyLoRALayer(n_slices=1, output_slices=(4,)) + manager._lora_modules = {"transformer.foo": layer} + + rank = 2 + adapters = {} + for aid in [1, 2, 3]: + lora = LoRALayerWeights( + module_name="foo", + rank=rank, + lora_alpha=rank, + lora_a=torch.ones((rank, 4)), + lora_b=torch.ones((4, rank)), + ) + adapters[aid] = type( + "LM", (), {"id": aid, "loras": {"transformer.foo": lora}, "get_lora": lambda self, k: self.loras.get(k)} + )() + + manager._registered_adapters = adapters + + # Activate 3 + manager._activate_adapters([1, 2, 3], [1.0, 1.0, 1.0]) + assert layer._n_active_adapters == 3 + + # Now activate only 1 — slots 1 and 2 should be reset + layer.set_calls.clear() + layer.reset_calls.clear() + + manager._activate_adapters([1], [1.0]) + + assert layer._n_active_adapters == 1 + # Slots 1 and 2 should have been reset + assert 1 in layer.reset_calls + assert 2 in layer.reset_calls + + +def test_multi_adapter_exceeds_max_loras(): + """Requesting more adapters than max_loras should raise ValueError.""" + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_loras=2, + max_cached_adapters=3, + ) + + with pytest.raises(ValueError, match="max_loras"): + manager.set_active_adapters( + [_dummy_lora_request(1), _dummy_lora_request(2), _dummy_lora_request(3)], + [1.0, 1.0, 1.0], + ) + + +def 
test_multi_adapter_mismatched_lengths(): + """lora_requests and lora_scales must have the same length.""" + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_loras=2, + ) + + with pytest.raises(ValueError, match="same length"): + manager.set_active_adapters( + [_dummy_lora_request(1), _dummy_lora_request(2)], + [1.0], + ) + + +def test_multi_adapter_empty_list_deactivates(): + """Empty request list should deactivate all adapters.""" + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_loras=2, + max_cached_adapters=2, + ) + + layer = _DummyLoRALayer(n_slices=1, output_slices=(4,)) + manager._lora_modules = {"transformer.foo": layer} + + # Activate something first + manager._active_adapter_ids = [1] + manager._active_adapter_scales = [1.0] + + manager.set_active_adapters([], []) + + assert manager._active_adapter_ids == [] + assert manager._active_adapter_scales == [] + + +def test_multi_adapter_skips_zero_scale(monkeypatch): + """Adapters with scale 0 should be filtered out.""" + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_loras=2, + max_cached_adapters=2, + ) + + layer = _DummyLoRALayer(n_slices=1, output_slices=(4,)) + manager._lora_modules = {"transformer.foo": layer} + + rank = 2 + lora = LoRALayerWeights( + module_name="foo", + rank=rank, + lora_alpha=rank, + lora_a=torch.ones((rank, 4)), + lora_b=torch.ones((4, rank)), + ) + for aid in [1, 2]: + manager._registered_adapters[aid] = type( + "LM", (), {"id": aid, "loras": {"transformer.foo": lora}, "get_lora": lambda self, k: self.loras.get(k)} + )() + + manager.set_active_adapters( + [_dummy_lora_request(1), _dummy_lora_request(2)], + [1.0, 0.0], # adapter 2 has scale 0 + ) + + # Only adapter 1 should be active + assert manager._active_adapter_ids == [1] + assert layer._n_active_adapters == 1 + 
+ +def test_multi_adapter_are_active_at_scales(): + """Test the _are_active_at_scales comparison.""" + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_loras=3, + ) + manager._active_adapter_ids = [1, 2] + manager._active_adapter_scales = [0.5, 0.75] + + assert manager._are_active_at_scales([1, 2], [0.5, 0.75]) + assert not manager._are_active_at_scales([1, 2], [0.5, 0.8]) + assert not manager._are_active_at_scales([1], [0.5]) + assert not manager._are_active_at_scales([2, 1], [0.75, 0.5]) + + def test_lora_manager_discovers_bagel_component(monkeypatch): """Verify that _replace_layers_with_lora finds layers under 'bagel'.""" import vllm_omni.diffusion.lora.manager as manager_mod diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py index 027dadb3f4..a71a3cb22a 100644 --- a/tests/e2e/offline_inference/test_diffusion_lora.py +++ b/tests/e2e/offline_inference/test_diffusion_lora.py @@ -25,56 +25,73 @@ models = ["Tongyi-MAI/Z-Image-Turbo"] +def _extract_images(outputs: list[OmniRequestOutput]): + if not outputs: + raise ValueError("Empty outputs from Omni.generate()") + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + return req_out.images + + +def _write_zimage_lora( + adapter_dir: Path, + *, + lora_b_value: float = 0.1, + lora_b_slice: str = "q", +) -> str: + """Write a fake Z-Image PEFT adapter to disk. + + Args: + adapter_dir: Directory to write the adapter files. + lora_b_value: Value to fill in the active lora_b slice. 
+ lora_b_slice: Which QKV slice to perturb ("q", "k", or "v"). + """ + adapter_dir.mkdir(parents=True, exist_ok=True) + + # Z-Image transformer uses dim=3840 by default (see ZImageTransformer2DModel). + dim = 3840 + module_name = "transformer.layers.0.attention.to_qkv" + rank = 1 + lora_a = torch.zeros((rank, dim), dtype=torch.float32) + lora_a[0, 0] = 1.0 + + # QKVParallelLinear packs (Q, K, V). With tp=1 and n_kv_heads==n_heads in Z-Image, + # each slice is `dim`, so total out dim is `3 * dim`. + lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) + slice_offsets = {"q": 0, "k": dim, "v": 2 * dim} + offset = slice_offsets[lora_b_slice] + lora_b[offset : offset + dim, 0] = lora_b_value + + save_file( + { + f"base_model.model.{module_name}.lora_A.weight": lora_a, + f"base_model.model.{module_name}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps( + { + "r": rank, + "lora_alpha": rank, + "target_modules": [module_name], + } + ), + encoding="utf-8", + ) + return str(adapter_dir) + + +@pytest.mark.diffusion +@pytest.mark.advanced_model @pytest.mark.parametrize("model_name", models) def test_diffusion_model(model_name: str, tmp_path: Path): - def _extract_images(outputs: list[OmniRequestOutput]): - if not outputs: - raise ValueError("Empty outputs from Omni.generate()") - first_output = outputs[0] - assert first_output.final_output_type == "image" - if not hasattr(first_output, "request_output") or not first_output.request_output: - raise ValueError("No request_output found in OmniRequestOutput") - - req_out = first_output.request_output - if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): - raise ValueError("Invalid request_output structure or missing 'images' key") - return req_out.images - - def _write_zimage_lora(adapter_dir: Path) -> str: - adapter_dir.mkdir(parents=True, exist_ok=True) - - # Z-Image transformer uses dim=3840 by 
default (see ZImageTransformer2DModel). - dim = 3840 - module_name = "transformer.layers.0.attention.to_qkv" - rank = 1 - lora_a = torch.zeros((rank, dim), dtype=torch.float32) - lora_a[0, 0] = 1.0 - - # QKVParallelLinear packs (Q, K, V). With tp=1 and n_kv_heads==n_heads in Z-Image, - # each slice is `dim`, so total out dim is `3 * dim`. - lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) - # Apply a visible delta to the Q slice only to keep the perturbation bounded. - lora_b[:dim, 0] = 0.1 - - save_file( - { - f"base_model.model.{module_name}.lora_A.weight": lora_a, - f"base_model.model.{module_name}.lora_B.weight": lora_b, - }, - str(adapter_dir / "adapter_model.safetensors"), - ) - (adapter_dir / "adapter_config.json").write_text( - json.dumps( - { - "r": rank, - "lora_alpha": rank, - "target_modules": [module_name], - } - ), - encoding="utf-8", - ) - return str(adapter_dir) - with OmniRunner(model_name) as runner: m = runner.omni # high resolution may cause OOM on L4 @@ -121,8 +138,8 @@ def _write_zimage_lora(adapter_dir: Path) -> str: guidance_scale=0.0, generator=torch.Generator(current_omni_platform.device_type).manual_seed(42), num_outputs_per_prompt=1, - lora_request=lora_request, - lora_scale=2.0, + lora_requests=[lora_request], + lora_scales=[2.0], ), ) images_lora = _extract_images(outputs_lora) @@ -134,3 +151,74 @@ def _write_zimage_lora(adapter_dir: Path) -> str: diff = np.abs(np.array(images[0], dtype=np.int16) - np.array(images_lora[0], dtype=np.int16)).mean() assert diff > 0.0 + + +@pytest.mark.diffusion +@pytest.mark.advanced_model +@pytest.mark.parametrize("model_name", models) +def test_diffusion_multi_lora_composition(model_name: str, tmp_path: Path): + """Test that composing two LoRA adapters produces different output than either alone.""" + if model_name != "Tongyi-MAI/Z-Image-Turbo": + pytest.skip("Multi-LoRA composition test is Z-Image specific") + + with OmniRunner(model_name, max_loras=2) as runner: + m = runner.omni + from 
vllm_omni.lora.request import LoRARequest + from vllm_omni.lora.utils import stable_lora_int_id + + height = 256 + width = 256 + prompt = "a photo of a cat sitting on a laptop keyboard" + + # Create two adapters that perturb different QKV slices + lora_dir_a = _write_zimage_lora(tmp_path / "lora_a", lora_b_value=0.1, lora_b_slice="q") + lora_dir_b = _write_zimage_lora(tmp_path / "lora_b", lora_b_value=0.1, lora_b_slice="k") + + req_a = LoRARequest("lora_a", stable_lora_int_id(lora_dir_a), lora_dir_a) + req_b = LoRARequest("lora_b", stable_lora_int_id(lora_dir_b), lora_dir_b) + + def _gen(**lora_kwargs): + return _extract_images( + m.generate( + prompt, + OmniDiffusionSamplingParams( + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator(current_omni_platform.device_type).manual_seed(42), + num_outputs_per_prompt=1, + **lora_kwargs, + ), + ) + ) + + import numpy as np + + # Baseline: no LoRA + img_base = np.array(_gen()[0], dtype=np.int16) + + # Single LoRA A + img_a = np.array(_gen(lora_requests=[req_a], lora_scales=[2.0])[0], dtype=np.int16) + + # Single LoRA B + img_b = np.array(_gen(lora_requests=[req_b], lora_scales=[2.0])[0], dtype=np.int16) + + # Composed: A + B + img_ab = np.array( + _gen(lora_requests=[req_a, req_b], lora_scales=[2.0, 2.0])[0], + dtype=np.int16, + ) + + # All four outputs should differ from each other + diff_base_a = np.abs(img_base - img_a).mean() + diff_base_b = np.abs(img_base - img_b).mean() + diff_base_ab = np.abs(img_base - img_ab).mean() + diff_a_ab = np.abs(img_a - img_ab).mean() + diff_b_ab = np.abs(img_b - img_ab).mean() + + assert diff_base_a > 0.0, "LoRA A should differ from baseline" + assert diff_base_b > 0.0, "LoRA B should differ from baseline" + assert diff_base_ab > 0.0, "Composed A+B should differ from baseline" + assert diff_a_ab > 0.0, "Composed A+B should differ from A alone" + assert diff_b_ab > 0.0, "Composed A+B should differ from B alone" diff --git 
a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index cf6841fd21..ec065d5b30 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -399,7 +399,8 @@ class OmniDiffusionConfig: # LoRA parameters lora_path: str | None = None lora_scale: float = 1.0 - max_cpu_loras: int | None = None + max_loras: int = 1 # max adapters composed per request (GPU slot count) + max_cpu_loras: int | None = None # max adapters cached on CPU (LRU) output_type: str = "pil" @@ -650,10 +651,13 @@ def __post_init__(self): f"got {type(self.quantization_config)!r}" ) + if self.max_loras < 1: + raise ValueError("max_loras must be >= 1 for diffusion LoRA") + if self.max_cpu_loras is None: - self.max_cpu_loras = 1 - elif self.max_cpu_loras < 1: - raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") + self.max_cpu_loras = max(self.max_loras, 1) + elif self.max_cpu_loras < self.max_loras: + raise ValueError(f"max_cpu_loras ({self.max_cpu_loras}) must be >= max_loras ({self.max_loras})") if self.diffusion_load_format != "diffusers" and (self.diffusers_load_kwargs or self.diffusers_call_kwargs): raise ValueError( diff --git a/vllm_omni/diffusion/lora/layers/base_linear.py b/vllm_omni/diffusion/lora/layers/base_linear.py index fe32868d08..32cbb58504 100644 --- a/vllm_omni/diffusion/lora/layers/base_linear.py +++ b/vllm_omni/diffusion/lora/layers/base_linear.py @@ -18,6 +18,9 @@ class DiffusionBaseLinearLayerWithLoRA(BaseLinearLayerWithLoRA): - Shrink: buffer = (x @ lora_a.T) - Expand: y += buffer @ lora_b.T + Multi-LoRA composition: multiple adapters can be active simultaneously, + each in its own slot; apply() accumulates deltas from all active slots. + All other functionality (weight management, TP slicing, forward logic) is inherited from vLLM's BaseLinearLayerWithLoRA. 
""" @@ -36,13 +39,19 @@ def create_lora_weights( modules = object.__getattribute__(self, "_modules") base_layer = modules.get("base_layer") or object.__getattribute__(self, "__dict__").get("base_layer") object.__setattr__(self, "_diffusion_base_layer_ref", base_layer) + self._n_active_adapters: int = 0 n_slices = getattr(self, "n_slices", 1) - self._diffusion_lora_active_slices = (False,) * int(n_slices) + # Per-adapter, per-slice active tracking: list of tuples + self._diffusion_lora_active_slices: list[tuple[bool, ...]] = [ + (False,) * int(n_slices) for _ in range(max_loras) + ] def reset_lora(self, index: int): super().reset_lora(index) n_slices = getattr(self, "n_slices", 1) - self._diffusion_lora_active_slices = (False,) * int(n_slices) + active_slices = getattr(self, "_diffusion_lora_active_slices", None) + if active_slices is not None and index < len(active_slices): + active_slices[index] = (False,) * int(n_slices) def set_lora( self, @@ -53,26 +62,30 @@ def set_lora( super().set_lora(index, lora_a, lora_b) # type: ignore[arg-type] n_slices = getattr(self, "n_slices", 1) + active_slices = getattr(self, "_diffusion_lora_active_slices", None) + if active_slices is None or index >= len(active_slices): + return + if isinstance(lora_a, list) or isinstance(lora_b, list): assert isinstance(lora_a, list) assert isinstance(lora_b, list) - active_slices = [] + slot_active = [] for a_i, b_i in zip(lora_a[:n_slices], lora_b[:n_slices]): - active_slices.append(a_i is not None and b_i is not None) - if len(active_slices) < n_slices: - active_slices.extend([False] * (n_slices - len(active_slices))) - self._diffusion_lora_active_slices = tuple(active_slices) + slot_active.append(a_i is not None and b_i is not None) + if len(slot_active) < n_slices: + slot_active.extend([False] * (n_slices - len(slot_active))) + active_slices[index] = tuple(slot_active) else: # Single-slice layer. 
- self._diffusion_lora_active_slices = (True,) + active_slices[index] = (True,) def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: """ override: Use simple matmul instead of punica_wrapper.add_lora_linear(). - This matches the exact computation in PunicaWrapperGPU.add_lora_linear() - for the single-LoRA case. For packed projections (e.g. fused QKV), we - apply LoRA per-slice using `output_slices`. + Supports multi-LoRA composition by accumulating deltas from all active + adapter slots. For packed projections (e.g. fused QKV), LoRA is applied + per-slice using `output_slices`. """ output = self.base_layer.quant_method.apply(self.base_layer, x, bias) @@ -80,9 +93,9 @@ def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tens return output if not self.lora_a_stacked or not self.lora_b_stacked: return output - # Fast path: if no LoRA is active for this layer, skip matmuls. - active_slices = getattr(self, "_diffusion_lora_active_slices", None) - if active_slices is not None and not any(active_slices): + + n_active = getattr(self, "_n_active_adapters", 0) + if n_active == 0: return output # In fully-sharded LoRA mode, vLLM uses an all-gather between shrink and @@ -111,25 +124,38 @@ def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tens f"lora_b_stacked={len(self.lora_b_stacked)}" ) - offset = 0 - for slice_idx, slice_size in enumerate(output_slices): - if active_slices is not None and slice_idx < len(active_slices) and not active_slices[slice_idx]: - offset += slice_size - continue + active_slices_list = getattr(self, "_diffusion_lora_active_slices", None) - A = self.lora_a_stacked[slice_idx][0, 0, :, :] # (rank, in_dim) - B = self.lora_b_stacked[slice_idx][0, 0, :, :] # (out_dim, rank) + for adapter_idx in range(n_active): + adapter_active_slices = ( + active_slices_list[adapter_idx] + if active_slices_list is not None and adapter_idx < len(active_slices_list) + else None + ) - if 
A.numel() == 0 or B.numel() == 0: + offset = 0 + for slice_idx, slice_size in enumerate(output_slices): + if ( + adapter_active_slices is not None + and slice_idx < len(adapter_active_slices) + and not adapter_active_slices[slice_idx] + ): + offset += slice_size + continue + + A = self.lora_a_stacked[slice_idx][adapter_idx, 0, :, :] # (rank, in_dim) + B = self.lora_b_stacked[slice_idx][adapter_idx, 0, :, :] # (out_dim, rank) + + if A.numel() == 0 or B.numel() == 0: + offset += slice_size + continue + + # LoRA shrink & expand: + # buffer = (x @ A.T) + # y += buffer @ B.T + delta = (x_flat @ A.t()) @ B.t() + y_flat[:, offset : offset + slice_size] = y_flat[:, offset : offset + slice_size] + delta offset += slice_size - continue - - # LoRA shrink & expand as in add_lora_linear(): - # buffer = (x @ A.T) - # y += buffer @ B.T - delta = (x_flat @ A.t()) @ B.t() - y_flat[:, offset : offset + slice_size] = y_flat[:, offset : offset + slice_size] + delta - offset += slice_size return y_flat.view(original_shape) diff --git a/vllm_omni/diffusion/lora/manager.py b/vllm_omni/diffusion/lora/manager.py index 63e8d9a96f..844b18b355 100644 --- a/vllm_omni/diffusion/lora/manager.py +++ b/vllm_omni/diffusion/lora/manager.py @@ -38,6 +38,7 @@ class DiffusionLoRAManager: Reuses vLLM's LoRA infrastructure, adapted for diffusion pipelines. Uses LRU cache management similar to LRUCacheLoRAModelManager. + Supports multi-LoRA composition: multiple adapters active simultaneously. """ # Valid max allowed ranks for LoRA in vLLM @@ -48,6 +49,7 @@ def __init__( pipeline: nn.Module, device: torch.device, dtype: torch.dtype, + max_loras: int = 1, max_cached_adapters: int = 1, lora_path: str | None = None, lora_scale: float = 1.0, @@ -56,6 +58,8 @@ def __init__( Initialize the DiffusionLoRAManager. Args: + max_loras: Maximum number of LoRA adapters that can be composed + (active simultaneously) per request. Controls GPU buffer slot count. 
max_cached_adapters: Maximum number of LoRA adapters to keep in the CPU-side cache (LRU). This mirrors vLLM's `max_cpu_loras` and is exposed to users via `OmniDiffusionConfig.max_cpu_loras`. @@ -63,6 +67,7 @@ def __init__( self.pipeline = pipeline self.device = device self.dtype = dtype + self.max_loras = max_loras # Cache supported/expected module suffixes once, before any layer # replacement happens. After LoRA layers are injected, the original @@ -79,8 +84,9 @@ def __init__( # LRU-style cache management self.max_cached_adapters = max_cached_adapters # max_cpu_loras self._registered_adapters: dict[int, LoRAModel] = {} # adapter_id -> LoRAModel - self._active_adapter_id: int | None = None - self._adapter_scales: dict[int, float] = {} # adapter_id -> external scale + # Currently active adapter ids (ordered) and their scales + self._active_adapter_ids: list[int] = [] + self._active_adapter_scales: list[float] = [] # LRU cache tracking (adapter_id -> last_used_time) self._adapter_access_order: OrderedDict[int, float] = OrderedDict() @@ -94,9 +100,11 @@ def __init__( self._max_lora_rank: int = 0 logger.info( - "Initializing DiffusionLoRAManager: device=%s, dtype=%s, max_cached_adapters=%d, static_lora_path=%s", + "Initializing DiffusionLoRAManager: device=%s, dtype=%s, " + "max_loras=%d, max_cached_adapters=%d, static_lora_path=%s", device, dtype, + max_loras, max_cached_adapters, lora_path, ) @@ -108,7 +116,7 @@ def __init__( lora_int_id=stable_lora_int_id(lora_path), lora_path=lora_path, ) - self.set_active_adapter(init_request, lora_scale) + self.set_active_adapters([init_request], [lora_scale]) def _compute_supported_lora_modules(self) -> set[str]: """Compute supported LoRA module suffixes for this pipeline. 
@@ -210,65 +218,68 @@ def _get_packed_sublayer_suffixes(self, packed_module_suffix: str, n_slices: int return None return sub_suffixes - def set_active_adapter(self, lora_request: LoRARequest | None, lora_scale: float = 1.0) -> None: - """Set the active LoRA adapter for the pipeline. + def set_active_adapters( + self, + lora_requests: list[LoRARequest], + lora_scales: list[float], + ) -> None: + """Set the active LoRA adapters for the pipeline. Args: - lora_request: The LoRA request, or None to deactivate all adapters. - lora_scale: The external scale for the LoRA adapter. + lora_requests: List of LoRA requests. Empty list deactivates all. + lora_scales: Per-adapter scales, must match length of lora_requests. """ - if lora_request is None: - if self._active_adapter_id is None: - logger.debug("No lora_request provided and adapters are already inactive") - return - logger.debug("No lora_request provided, deactivating all LoRA adapters") + if not lora_requests: + logger.debug("No lora_requests provided, deactivating all LoRA adapters") self._deactivate_all_adapters() return - elif math.isclose(0.0, lora_scale): - if self._active_adapter_id is None: - logger.debug("Received LoRA scale 0 with adapters already inactive") - return - logger.warning("Received a request with LoRA scale 0; deactivating all LoRA adapters") + + if len(lora_requests) != len(lora_scales): + raise ValueError( + f"lora_requests ({len(lora_requests)}) and lora_scales ({len(lora_scales)}) must have the same length" + ) + + # scale=0.0 still occupies a slot; it is not equivalent to omitting the adapter. 
+ if len(lora_requests) > self.max_loras: + raise ValueError(f"Requested {len(lora_requests)} adapters but max_loras={self.max_loras}") + + # Filter out zero-scale adapters + active_requests: list[LoRARequest] = [] + active_scales: list[float] = [] + for req, scale in zip(lora_requests, lora_scales): + if math.isclose(0.0, scale): + logger.debug("Skipping adapter %s with scale 0", req.lora_name) + continue + active_requests.append(req) + active_scales.append(scale) + + if not active_requests: + logger.warning("All adapters have scale 0; deactivating all LoRA adapters") self._deactivate_all_adapters() return - adapter_id = lora_request.lora_int_id - logger.debug( - "Setting active adapter: id=%d, name=%s, path=%s, scale=%.2f, cache_size=%d/%d", - adapter_id, - lora_request.lora_name, - lora_request.lora_path, - lora_scale, - len(self._registered_adapters), - self.max_cached_adapters, - ) - if adapter_id not in self._registered_adapters: - logger.info("Loading new adapter: id=%d, name=%s", adapter_id, lora_request.lora_name) - # Add the adapter + add to the cache - self.add_adapter(lora_request) - else: - # Just touch the cache access order - self._touch_adapter_info(adapter_id) + # Ensure all adapters are registered (loaded into cache) + adapter_ids: list[int] = [] + for req in active_requests: + adapter_id = req.lora_int_id + if adapter_id not in self._registered_adapters: + logger.info("Loading new adapter: id=%d, name=%s", adapter_id, req.lora_name) + self.add_adapter(req) + else: + self._touch_adapter_info(adapter_id) + adapter_ids.append(adapter_id) - self._activate_adapter(adapter_id, lora_scale) + self._activate_adapters(adapter_ids, active_scales) def _touch_adapter_info(self, adapter_id): """Update the current caching ordering info.""" self._adapter_access_order[adapter_id] = time.time() self._adapter_access_order.move_to_end(adapter_id) - def _update_adapter_scale(self, adapter_id: int, lora_scale: float): - """Update the adapter scale for a given adapter 
ID. To avoid potential - issues with using Floats as keys, for now, we round float values to - 3 decimal points. - """ - scale = DiffusionLoRAManager._get_rounded_scale(lora_scale) - self._adapter_scales[adapter_id] = scale - @staticmethod def _get_rounded_scale(lora_scale: float): - """Normalizes a lora scale for use as a key in the _adapter_scales - dict; for now we just round scales to 3 decimal places. + """Normalizes a lora scale for use as comparison; + for now we just round scales to 3 decimal places. """ return round(lora_scale, 3) @@ -337,6 +348,16 @@ def _get_packed_modules_list(self, module: nn.Module) -> list[str]: return ["0", "1"] return [] + def _make_lora_config(self) -> LoRAConfig: + """Build a LoRAConfig using current manager state.""" + return LoRAConfig( + max_lora_rank=self._max_lora_rank, + max_loras=self.max_loras, + max_cpu_loras=self.max_cached_adapters, + lora_dtype=self.dtype, + fully_sharded_loras=False, + ) + def _replace_layers_with_lora(self, peft_helper: PEFTHelper) -> None: self._ensure_max_lora_rank(peft_helper.r) @@ -357,14 +378,7 @@ def _matches_target(module_name: str) -> bool: return True return _match_target_modules(module_name, target_modules_list) - # dummy lora config - lora_config = LoRAConfig( - max_lora_rank=self._max_lora_rank, - max_loras=1, - max_cpu_loras=self.max_cached_adapters, - lora_dtype=self.dtype, - fully_sharded_loras=False, - ) + lora_config = self._make_lora_config() for component_name in ("transformer", "transformer_2", "dit", "bagel"): if not hasattr(self.pipeline, component_name): @@ -410,7 +424,7 @@ def _matches_target(module_name: str) -> bool: for module_name, full_module_name, module, packed_modules_list in pending_replacements: lora_layer = from_layer_diffusion( layer=module, - max_loras=1, + max_loras=self.max_loras, lora_config=lora_config, packed_modules_list=packed_modules_list, model_config=None, @@ -426,7 +440,7 @@ def _ensure_max_lora_rank(self, min_rank: int) -> None: We allocate per-layer 
LoRA buffers once when we first replace layers. If a later adapter has a larger rank, we need to reinitialize those - buffers and re-apply the currently active adapter. + buffers and re-apply the currently active adapters. """ if min_rank <= self._max_lora_rank: return @@ -439,24 +453,19 @@ def _ensure_max_lora_rank(self, min_rank: int) -> None: if not self._lora_modules: return - lora_config = LoRAConfig( - max_lora_rank=self._max_lora_rank, - max_loras=1, - max_cpu_loras=self.max_cached_adapters, - lora_dtype=self.dtype, - fully_sharded_loras=False, - ) + lora_config = self._make_lora_config() # Recreate per-layer buffers with the new maximum rank. for lora_layer in self._lora_modules.values(): - lora_layer.create_lora_weights(max_loras=1, lora_config=lora_config, model_config=None) + lora_layer.create_lora_weights(max_loras=self.max_loras, lora_config=lora_config, model_config=None) - # Re-apply active adapter if needed (buffers were reset). - if self._active_adapter_id is not None: - active_id = self._active_adapter_id - active_scale = self._adapter_scales[active_id] - self._active_adapter_id = None - self._activate_adapter(active_id, active_scale) + # Re-apply active adapters if needed (buffers were reset). 
+ if self._active_adapter_ids: + saved_ids = list(self._active_adapter_ids) + saved_scales = list(self._active_adapter_scales) + self._active_adapter_ids = [] + self._active_adapter_scales = [] + self._activate_adapters(saved_ids, saved_scales) @classmethod def _get_smallest_valid_max_rank(cls, min_rank: int) -> int: @@ -494,137 +503,169 @@ def _get_lora_weights( module_suffix = full_module_name.split(".")[-1] return lora_model.get_lora(module_suffix) - def _is_active_at_scale(self, adapter_id: int, scale: float) -> bool: - """True if the adapter_id is active and the current scale matches.""" - rounded_scale = DiffusionLoRAManager._get_rounded_scale(scale) - is_active = self._active_adapter_id == adapter_id - matches_scale = self._adapter_scales.get(adapter_id) == rounded_scale - return is_active and matches_scale - - def _activate_adapter(self, adapter_id: int, scale: float) -> None: - if self._is_active_at_scale(adapter_id, scale): - logger.debug("Adapter %d already active at scale %.3f skipping", adapter_id, scale) - return + def _are_active_at_scales(self, adapter_ids: list[int], scales: list[float]) -> bool: + """True if the given adapters are already active at the given scales.""" + # TODO: order-sensitive — re-requesting the same adapter set in a + # different order forces full re-activation even though LoRA deltas + # compose via addition (commutative). Consider comparing as a + # (id, scale) multiset to skip redundant work. 
+ if len(adapter_ids) != len(self._active_adapter_ids): + return False + for aid, scale, active_aid, active_scale in zip( + adapter_ids, scales, self._active_adapter_ids, self._active_adapter_scales + ): + if aid != active_aid: + return False + if self._get_rounded_scale(scale) != self._get_rounded_scale(active_scale): + return False + return True - logger.info("Activating adapter: id=%d", adapter_id) - lora_model = self._registered_adapters[adapter_id] + def _set_lora_for_layer( + self, + lora_layer: BaseLayerWithLoRA, + full_module_name: str, + slot_index: int, + lora_model: LoRAModel, + scale: float, + ) -> None: + """Set LoRA weights for a single adapter slot on a single layer. + + Dispatches across four shapes: (1) adapter has no entry for this + module — fall through to per-slice suffix lookup for packed layers, + otherwise reset the slot; (2) adapter entry is ``PackedLoRALayerWeights`` + with per-slice tensors already separated; (3) adapter entry is a single + fused tensor but the target layer is multi-slice, so split ``lora_b`` + along ``output_slices``; (4) single-slice fused — apply directly. + """ + lora_weights = self._get_lora_weights(lora_model, full_module_name) - # activate weights in each LoRA layer - for full_module_name, lora_layer in self._lora_modules.items(): - lora_weights = self._get_lora_weights(lora_model, full_module_name) - - if lora_weights is None: - n_slices = getattr(lora_layer, "n_slices", 1) - if n_slices > 1: - prefix, _, packed_suffix = full_module_name.rpartition(".") - sub_suffixes = self._get_packed_sublayer_suffixes(packed_suffix, n_slices) - if sub_suffixes is None: - lora_layer.reset_lora(0) + # Case 1: no direct entry. Multi-slice layers may still match via the + # unpacked sub-suffixes (e.g. ``q_proj``/``k_proj``/``v_proj`` when + # target is a fused ``qkv_proj``); single-slice layers just miss. 
+ if lora_weights is None: + n_slices = getattr(lora_layer, "n_slices", 1) + if n_slices > 1: + prefix, _, packed_suffix = full_module_name.rpartition(".") + sub_suffixes = self._get_packed_sublayer_suffixes(packed_suffix, n_slices) + if sub_suffixes is None: + lora_layer.reset_lora(slot_index) + return + + # Gather per-slice weights; each slice may independently miss. + sub_loras: list[LoRALayerWeights | None] = [] + any_found = False + for sub_suffix in sub_suffixes: + sub_full_name = f"{prefix}.{sub_suffix}" if prefix else sub_suffix + sub_lora = self._get_lora_weights(lora_model, sub_full_name) + if sub_lora is not None: + any_found = True + # Packed layers expect plain (non-packed) subloras. + if isinstance(sub_lora, PackedLoRALayerWeights): + sub_lora = None + sub_loras.append(sub_lora if isinstance(sub_lora, LoRALayerWeights) else None) + + if not any_found: + lora_layer.reset_lora(slot_index) + return + + # Build per-slice A/B lists; None slots leave that slice at zero. + lora_a_list: list[torch.Tensor | None] = [] + lora_b_list: list[torch.Tensor | None] = [] + for sub_lora in sub_loras: + if sub_lora is None: + lora_a_list.append(None) + lora_b_list.append(None) continue + lora_a_list.append(sub_lora.lora_a) + lora_b_list.append(sub_lora.lora_b * scale) - sub_loras: list[LoRALayerWeights | None] = [] - any_found = False - for sub_suffix in sub_suffixes: - sub_full_name = f"{prefix}.{sub_suffix}" if prefix else sub_suffix - sub_lora = self._get_lora_weights(lora_model, sub_full_name) - if sub_lora is not None: - any_found = True - # Packed layers expect plain (non-packed) subloras. - if isinstance(sub_lora, PackedLoRALayerWeights): - sub_lora = None - sub_loras.append(sub_lora if isinstance(sub_lora, LoRALayerWeights) else None) - - if not any_found: - lora_layer.reset_lora(0) - continue + lora_layer.set_lora(index=slot_index, lora_a=lora_a_list, lora_b=lora_b_list) + return + else: + # Single-slice layer with no adapter entry: this slot is inactive. 
+ lora_layer.reset_lora(slot_index) + return - lora_a_list: list[torch.Tensor | None] = [] - lora_b_list: list[torch.Tensor | None] = [] - for sub_lora in sub_loras: - if sub_lora is None: - lora_a_list.append(None) - lora_b_list.append(None) - continue - lora_a_list.append(sub_lora.lora_a) - lora_b_list.append(sub_lora.lora_b * scale) - - lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) - logger.debug( - "Activated packed LoRA for %s via submodules=%s (scale=%.2f)", - full_module_name, - sub_suffixes, - scale, - ) - else: - lora_layer.reset_lora(0) - continue + # Case 2: packed LoRA weights already provide per-slice tensors. + if isinstance(lora_weights, PackedLoRALayerWeights): + lora_a_list = lora_weights.lora_a + lora_b_list = [ + None if b is None else b * scale # type: ignore[operator] + for b in lora_weights.lora_b + ] + lora_layer.set_lora(index=slot_index, lora_a=lora_a_list, lora_b=lora_b_list) + return - # Packed LoRA weights already provide per-slice tensors. - if isinstance(lora_weights, PackedLoRALayerWeights): - lora_a_list = lora_weights.lora_a - lora_b_list = [ - None if b is None else b * scale # type: ignore[operator] - for b in lora_weights.lora_b - ] - lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) - logger.debug( - "Activated packed LoRA for %s (scale=%.2f)", + # Case 3: fused (non-packed) weights targeting a multi-slice layer. + # Split B along ``output_slices`` so each slice receives its portion; + # A is shared across slices (standard PEFT fused-QKV convention). + n_slices = getattr(lora_layer, "n_slices", 1) + if n_slices > 1: + output_slices = getattr(lora_layer, "output_slices", None) + if output_slices is None: + lora_layer.reset_lora(slot_index) + return + + total = sum(output_slices) + if lora_weights.lora_b.shape[0] != total: + # Shape mismatch means we can't safely split; skip this layer + # rather than silently produce garbage outputs. 
+ logger.warning( + "Skipping LoRA for %s due to shape mismatch: lora_b[0]=%d != sum(output_slices)=%d", full_module_name, - scale, + lora_weights.lora_b.shape[0], + total, ) - continue + lora_layer.reset_lora(slot_index) + return - # Fused (non-packed) weights: if the layer is multi-slice, split B. - n_slices = getattr(lora_layer, "n_slices", 1) - if n_slices > 1: - output_slices = getattr(lora_layer, "output_slices", None) - if output_slices is None: - lora_layer.reset_lora(0) - continue + b_splits = list(torch.split(lora_weights.lora_b, list(output_slices), dim=0)) + lora_a_list = [lora_weights.lora_a] * n_slices + lora_b_list = [b * scale for b in b_splits] + lora_layer.set_lora(index=slot_index, lora_a=lora_a_list, lora_b=lora_b_list) + return - total = sum(output_slices) - if lora_weights.lora_b.shape[0] != total: - logger.warning( - "Skipping LoRA for %s due to shape mismatch: lora_b[0]=%d != sum(output_slices)=%d", - full_module_name, - lora_weights.lora_b.shape[0], - total, - ) - lora_layer.reset_lora(0) - continue + # Case 4: single-slice fused — apply A and scaled B directly. 
+ scaled_lora_b = lora_weights.lora_b * scale + lora_layer.set_lora(index=slot_index, lora_a=lora_weights.lora_a, lora_b=scaled_lora_b) - b_splits = list(torch.split(lora_weights.lora_b, list(output_slices), dim=0)) - lora_a_list = [lora_weights.lora_a] * n_slices - lora_b_list = [b * scale for b in b_splits] - lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) - logger.debug( - "Activated fused LoRA for packed layer %s (scale=%.2f)", - full_module_name, - scale, - ) - continue + def _activate_adapters(self, adapter_ids: list[int], scales: list[float]) -> None: + """Activate multiple adapters simultaneously, each in its own slot.""" + if self._are_active_at_scales(adapter_ids, scales): + logger.debug("Adapters already active at requested scales, skipping") + return - scaled_lora_b = lora_weights.lora_b * scale - lora_layer.set_lora(index=0, lora_a=lora_weights.lora_a, lora_b=scaled_lora_b) - logger.debug( - "Activated LoRA for %s: lora_a shape=%s, lora_b shape=%s, scale=%.2f", - full_module_name, - lora_weights.lora_a.shape, - lora_weights.lora_b.shape, - scale, - ) + logger.info("Activating %d adapter(s): ids=%s", len(adapter_ids), adapter_ids) + + for full_module_name, lora_layer in self._lora_modules.items(): + # Set each active adapter into its slot + for slot_index, (adapter_id, scale) in enumerate(zip(adapter_ids, scales)): + lora_model = self._registered_adapters[adapter_id] + self._set_lora_for_layer(lora_layer, full_module_name, slot_index, lora_model, scale) + + # Reset unused slots + for slot_index in range(len(adapter_ids), self.max_loras): + lora_layer.reset_lora(slot_index) + + # Tell each layer how many adapters are active + n_active = len(adapter_ids) + for lora_layer in self._lora_modules.values(): + lora_layer._n_active_adapters = n_active # type: ignore[attr-defined] - self._active_adapter_id = adapter_id - self._update_adapter_scale(adapter_id, scale) + self._active_adapter_ids = list(adapter_ids) + self._active_adapter_scales 
= list(scales) def _deactivate_all_adapters(self) -> None: - if self._active_adapter_id is None: + if not self._active_adapter_ids: logger.debug("All adapters already inactive") return logger.info("Deactivating all adapters: %d layers", len(self._lora_modules)) for lora_layer in self._lora_modules.values(): - lora_layer.reset_lora(0) - self._active_adapter_id = None + for slot_index in range(self.max_loras): + lora_layer.reset_lora(slot_index) + lora_layer._n_active_adapters = 0 # type: ignore[attr-defined] + self._active_adapter_ids = [] + self._active_adapter_scales = [] logger.debug("All adapters deactivated") def _evict_for_new_adapter(self) -> None: @@ -687,11 +728,10 @@ def remove_adapter(self, adapter_id: int) -> bool: return False logger.info("Removing adapter: id=%d", adapter_id) - if self._active_adapter_id == adapter_id: + if adapter_id in self._active_adapter_ids: self._deactivate_all_adapters() del self._registered_adapters[adapter_id] - self._adapter_scales.pop(adapter_id, None) self._adapter_access_order.pop(adapter_id, None) self._pinned_adapters.discard(adapter_id) logger.debug( diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py index 4f62d72c9b..859dadd249 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py @@ -1275,7 +1275,7 @@ def forward( lora_int_id=1, lora_path=lora_path, ) - self.lora_manager.set_active_adapter(lora_request, lora_scale=1.0) + self.lora_manager.set_active_adapters([lora_request], [1.0]) # Change scheduler to use Stage 2 distilled sigmas as is new_scheduler = FlowMatchEulerDiscreteScheduler.from_config( diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py index 4cc65f7490..efb8466211 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py +++ 
b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py @@ -862,7 +862,7 @@ def forward( lora_int_id=1, lora_path=lora_path, ) - self.lora_manager.set_active_adapter(lora_request, lora_scale=1.0) + self.lora_manager.set_active_adapters([lora_request], [1.0]) # Change scheduler to use Stage 2 distilled sigmas as is new_scheduler = FlowMatchEulerDiscreteScheduler.from_config( diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py index d36c5c644c..eeb5e81658 100644 --- a/vllm_omni/diffusion/stage_diffusion_proc.py +++ b/vllm_omni/diffusion/stage_diffusion_proc.py @@ -76,12 +76,13 @@ def _enrich_config(self) -> None: def _reconstruct_sampling_params(self, sampling_params_dict: dict) -> OmniDiffusionSamplingParams: """Reconstruct OmniDiffusionSamplingParams from a dict, handling LoRA.""" - lora_req = sampling_params_dict.get("lora_request") - if lora_req is not None: + lora_reqs = sampling_params_dict.get("lora_requests") + if lora_reqs: from vllm.lora.request import LoRARequest - if not isinstance(lora_req, LoRARequest): - sampling_params_dict["lora_request"] = msgspec.convert(lora_req, LoRARequest) + sampling_params_dict["lora_requests"] = [ + req if isinstance(req, LoRARequest) else msgspec.convert(req, LoRARequest) for req in lora_reqs + ] return OmniDiffusionSamplingParams(**sampling_params_dict) diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 927bbeb1a2..f47d61a1f8 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -214,6 +214,7 @@ def init_lora_manager(self) -> None: pipeline=self.model_runner.pipeline, device=self.device, dtype=self.od_config.dtype, + max_loras=self.od_config.max_loras, max_cached_adapters=self.od_config.max_cpu_loras, lora_path=self.od_config.lora_path, lora_scale=self.od_config.lora_scale, @@ -249,9 +250,12 @@ def execute_model(self, req: 
OmniDiffusionRequest, od_config: OmniDiffusionConfi assert self.model_runner is not None, "Model runner not initialized" if self.lora_manager is not None: try: - self.lora_manager.set_active_adapter(req.sampling_params.lora_request, req.sampling_params.lora_scale) + self.lora_manager.set_active_adapters( + req.sampling_params.lora_requests, + req.sampling_params.lora_scales, + ) except Exception as exc: - if req.sampling_params.lora_request is not None: + if req.sampling_params.lora_requests: raise logger.warning("LoRA activation skipped: %s", exc) profiler = self._get_profiler() @@ -268,9 +272,9 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner if self.lora_manager is not None: # Step mode does not support LoRA yet. Clear any previously active # adapter first so worker-local LoRA state cannot leak in. - self.lora_manager.set_active_adapter(None) + self.lora_manager.set_active_adapters([], []) - if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs): + if any(new_req.req.sampling_params.lora_requests for new_req in scheduler_output.scheduled_new_reqs): raise ValueError("Step mode does not support LoRA yet.") profiler = self._get_profiler() ctx = profiler.annotate_context_manager("diffusion_step") if profiler else nullcontext() diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index a37afd24b4..6546d987f8 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1320,6 +1320,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4), "quantization": kwargs.get("quantization", None), + "max_loras": kwargs.get("max_loras", 1), "enable_diffusion_pipeline_profiler": kwargs.get("enable_diffusion_pipeline_profiler", False), 
"enable_ar_profiler": kwargs.get("enable_ar_profiler", False), **( @@ -1457,6 +1458,10 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st if lora_scale is not None: if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: cfg.engine_args.lora_scale = lora_scale + max_loras = kwargs.get("max_loras") + if max_loras is not None: + if not hasattr(cfg.engine_args, "max_loras") or cfg.engine_args.max_loras is None: + cfg.engine_args.max_loras = max_loras quantization_config = kwargs.get("quantization_config") if quantization_config is not None: if ( diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 646bbd6f91..508df36f5c 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -120,7 +120,7 @@ from vllm_omni.entrypoints.openai.serving_video_stream import OmniStreamingVideoHandler from vllm_omni.entrypoints.openai.storage import STORAGE_MANAGER from vllm_omni.entrypoints.openai.stores import VIDEO_STORE, VIDEO_TASKS -from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request +from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_requests from vllm_omni.entrypoints.openai.video_api_utils import decode_input_reference from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt @@ -1475,7 +1475,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) extra_body["generator_device"] = request.generator_device if request.lora is not None: # Keep /images validation semantics: invalid LoRA should fail with 400. 
- _parse_lora_request(request.lora) + _parse_lora_requests(request.lora) extra_body["lora"] = request.lora generation_result = await chat_handler.generate_diffusion_images( @@ -1505,9 +1505,10 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) if extra_args: gen_params.extra_args = extra_args # Parse per-request LoRA (compatible with chat's extra_body.lora shape). - lora_request, lora_scale = _parse_lora_request(request.lora) - _update_if_not_none(gen_params, "lora_request", lora_request) - _update_if_not_none(gen_params, "lora_scale", lora_scale) + lora_reqs, lora_scales = _parse_lora_requests(request.lora) + if lora_reqs: + gen_params.lora_requests = lora_reqs + gen_params.lora_scales = lora_scales # Parse and add size if provided width, height = None, None @@ -1726,9 +1727,10 @@ async def edit_images( _update_if_not_none(gen_params, "num_outputs_per_prompt", n) # 3.1 Parse per-request LoRA (compatible with chat's extra_body.lora shape). lora_dict = _get_lora_from_json_str(lora) - lora_request, lora_scale = _parse_lora_request(lora_dict) - _update_if_not_none(gen_params, "lora_request", lora_request) - _update_if_not_none(gen_params, "lora_scale", lora_scale) + lora_reqs, lora_scales = _parse_lora_requests(lora_dict) + if lora_reqs: + gen_params.lora_requests = lora_reqs + gen_params.lora_scales = lora_scales # 3.2 Validate resolution if provided if resolution is not None and resolution not in SUPPORTED_LAYERED_RESOLUTIONS: raise HTTPException( @@ -1845,7 +1847,7 @@ async def edit_images( if lora is not None: # Validate LoRA, then pass through. 
lora_dict = _get_lora_from_json_str(lora) - _parse_lora_request(lora_dict) + _parse_lora_requests(lora_dict) extra_body["lora"] = lora_dict prompt_text = prompt.get("prompt", "") @@ -1988,15 +1990,15 @@ def _get_lora_from_json_str(lora_body): except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid LoRA JSON string") - if not isinstance(lora_dict, dict): - raise HTTPException(status_code=400, detail="LoRA must be a JSON object") + if not isinstance(lora_dict, (dict, list)): + raise HTTPException(status_code=400, detail="LoRA must be a JSON object or array of objects") return lora_dict -def _parse_lora_request(lora_body: dict[str, Any]): +def _parse_lora_requests(lora_body: dict[str, Any] | list[dict[str, Any]] | None): try: - return parse_lora_request(lora_body) + return parse_lora_requests(lora_body) except ValueError as e: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, diff --git a/vllm_omni/entrypoints/openai/protocol/images.py b/vllm_omni/entrypoints/openai/protocol/images.py index 0fb22a548c..a4b9eef99f 100644 --- a/vllm_omni/entrypoints/openai/protocol/images.py +++ b/vllm_omni/entrypoints/openai/protocol/images.py @@ -125,10 +125,11 @@ def validate_use_system_prompt(cls, v): # vllm-omni extension for per-request LoRA. # This mirrors the `extra_body.lora` convention in /v1/chat/completions. - lora: dict[str, Any] | None = Field( + lora: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description=( - "Optional LoRA adapter for this request. Expected shape: " + "Optional LoRA adapter(s) for this request. " + "A single dict or a list of dicts, each with shape: " "{name/path/scale/int_id}. Field names are flexible " "(e.g. name|lora_name|adapter, path|lora_path|local_path, " "scale|lora_scale, int_id|lora_int_id)." 
diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index d46c8d43d6..5799fb4d56 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -174,10 +174,11 @@ class VideoGenerationRequest(BaseModel): ) # vllm-omni extension for per-request LoRA. - lora: dict[str, Any] | None = Field( + lora: dict[str, Any] | list[dict[str, Any]] | None = Field( default=None, description=( - "Optional LoRA adapter for this request. Expected shape: " + "Optional LoRA adapter(s) for this request. " + "A single dict or a list of dicts, each with shape: " "{name/path/scale/int_id}. Field names are flexible " "(e.g. name|lora_name|adapter, path|lora_path|local_path, " "scale|lora_scale, int_id|lora_int_id)." diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index fd6484e6df..bdcb618d85 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -90,7 +90,7 @@ from vllm_omni.entrypoints.openai.utils import ( get_stage_type, get_supported_speakers_from_hf_config, - parse_lora_request, + parse_lora_requests, validate_requested_speaker, ) from vllm_omni.lora.request import LoRARequest @@ -2224,13 +2224,12 @@ def _build_multistage_generation_inputs( layers=layers, resolution=resolution, ) - if lora_body and isinstance(lora_body, dict): + if lora_body and isinstance(lora_body, (dict, list)): try: - lora_req, lora_scale = parse_lora_request(lora_body) - if lora_req is not None: - default_stage_params.lora_request = lora_req - if lora_scale is not None: - default_stage_params.lora_scale = lora_scale + lora_reqs, lora_scales = parse_lora_requests(lora_body) + if lora_reqs: + default_stage_params.lora_requests = lora_reqs + default_stage_params.lora_scales = lora_scales except Exception as e: # pragma: no cover - safeguard logger.warning("Failed to parse LoRA request: %s", e) 
@@ -2291,13 +2290,12 @@ async def generate_diffusion_images(
                 strength=extra_body.get("strength"),
             )

-        if lora_body and isinstance(lora_body, dict):
+        if lora_body and isinstance(lora_body, (dict, list)):
             try:
-                lora_req, lora_scale = parse_lora_request(lora_body)
-                if lora_req is not None:
-                    gen_params.lora_request = lora_req
-                if lora_scale is not None:
-                    gen_params.lora_scale = lora_scale
+                lora_reqs, lora_scales = parse_lora_requests(lora_body)
+                if lora_reqs:
+                    gen_params.lora_requests = lora_reqs
+                    gen_params.lora_scales = lora_scales
             except Exception as e:  # pragma: no cover - safeguard
                 logger.warning("Failed to parse LoRA request: %s", e)

@@ -2488,13 +2486,12 @@ async def _create_diffusion_chat_completion(
                 gen_params.resolution = resolution

             # Parse per-request LoRA.
-            if lora_body and isinstance(lora_body, dict):
+            if lora_body:
                 try:
-                    lora_req, lora_scale = parse_lora_request(lora_body)
-                    if lora_req is not None:
-                        gen_params.lora_request = lora_req
-                    if lora_scale is not None:
-                        gen_params.lora_scale = lora_scale
+                    lora_reqs, lora_scales = parse_lora_requests(lora_body)
+                    if lora_reqs:
+                        gen_params.lora_requests = lora_reqs
+                        gen_params.lora_scales = lora_scales
                 except Exception as e:  # pragma: no cover - safeguard
                     logger.warning("Failed to parse LoRA request: %s", e)

diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py
index a4be330eb4..d679f63cbf 100644
--- a/vllm_omni/entrypoints/openai/serving_video.py
+++ b/vllm_omni/entrypoints/openai/serving_video.py
@@ -20,7 +20,7 @@
     VideoGenerationRequest,
     VideoGenerationResponse,
 )
-from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request
+from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_requests
 from vllm_omni.entrypoints.openai.video_api_utils import _encode_video_bytes, encode_video_base64
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams, OmniTextPrompt

@@ -283,19 +283,18 @@ def _resolve_default_sampling_params(self) -> OmniDiffusionSamplingParams:
     @staticmethod
     def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None:
         try:
-            lora_request, lora_scale = parse_lora_request(lora_body)
+            lora_reqs, lora_scales = parse_lora_requests(lora_body)
         except ValueError as e:
             raise HTTPException(
                 status_code=HTTPStatus.BAD_REQUEST.value,
                 detail=str(e),
             ) from e

-        if lora_request is None:
+        if not lora_reqs:
             return

-        gen_params.lora_request = lora_request
-        if lora_scale is not None:
-            gen_params.lora_scale = lora_scale
+        gen_params.lora_requests = lora_reqs
+        gen_params.lora_scales = lora_scales

     async def _run_generation(
         self,
diff --git a/vllm_omni/entrypoints/openai/utils.py b/vllm_omni/entrypoints/openai/utils.py
index f411526fdb..d4258138c9 100644
--- a/vllm_omni/entrypoints/openai/utils.py
+++ b/vllm_omni/entrypoints/openai/utils.py
@@ -55,6 +55,42 @@ def parse_lora_request(lora_body: Any) -> tuple[LoRARequest | None, float | None
     return LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)), scale


+def parse_lora_requests(
+    lora_body: dict[str, Any] | list[dict[str, Any]] | None,
+) -> tuple[list[LoRARequest], list[float]]:
+    """Parse one or more LoRA objects into parallel lists of requests and scales.
+
+    Handles three shapes:
+    - ``None`` -> empty lists
+    - ``dict`` -> single adapter, wrapped in a one-element list
+    - ``list`` -> multiple adapters
+
+    Returns:
+        ``(lora_requests, lora_scales)`` with matching lengths.
+    """
+    if lora_body is None:
+        return [], []
+
+    items: list[dict[str, Any]]
+    if isinstance(lora_body, dict):
+        items = [lora_body]
+    elif isinstance(lora_body, list):
+        items = lora_body
+    else:
+        raise ValueError("Invalid lora field: expected a dict, list of dicts, or null.")
+
+    requests: list[LoRARequest] = []
+    scales: list[float] = []
+    for item in items:
+        req, scale = parse_lora_request(item)
+        if req is None:
+            continue
+        requests.append(req)
+        scales.append(scale if scale is not None else 1.0)
+
+    return requests, scales
+
+
 def get_supported_speakers_from_hf_config(hf_config: Any) -> set[str]:
     """Extract supported speaker names from a model hf_config."""
     config = (
diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py
index e4c33a58c2..b0127e5acd 100644
--- a/vllm_omni/inputs/data.py
+++ b/vllm_omni/inputs/data.py
@@ -287,9 +287,9 @@ class OmniDiffusionSamplingParams:
     save_output: bool = True
     return_frames: bool = False

-    # LoRA
-    lora_request: LoRARequest | None = None
-    lora_scale: float = 1.0
+    # LoRA — multiple adapters can be composed per request
+    lora_requests: list[LoRARequest] = field(default_factory=list)
+    lora_scales: list[float] = field(default_factory=list)

     # STA parameters
     STA_param: list | None = None
@@ -334,6 +334,13 @@ def resolved_frame_rate(self) -> float | None:

         return float(fps)

+    def __post_init__(self):
+        # Default per-adapter LoRA scales to 1.0 when adapters are supplied
+        # without explicit scales, so callers can opt in by setting
+        # lora_requests alone.
+        if self.lora_requests and not self.lora_scales:
+            self.lora_scales = [1.0] * len(self.lora_requests)
+
     def __str__(self):
         return pprint.pformat(asdict(self), indent=2, width=120)
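Reviewer note: the `__post_init__` defaulting rule in the `data.py` hunk can be checked in isolation. The sketch below is a minimal stand-alone reproduction, not the project's code: `LoRARequestStub` and `SamplingParamsSketch` are hypothetical stand-ins that copy only the two LoRA fields and the hook from the diff, so it runs without vLLM installed.

```python
from dataclasses import dataclass, field


@dataclass
class LoRARequestStub:
    # Hypothetical stand-in for LoRARequest; carries only the three
    # positional values the diff passes when constructing one.
    lora_name: str
    lora_int_id: int
    lora_path: str


@dataclass
class SamplingParamsSketch:
    # Hypothetical reduction of OmniDiffusionSamplingParams to its LoRA fields.
    lora_requests: list[LoRARequestStub] = field(default_factory=list)
    lora_scales: list[float] = field(default_factory=list)

    def __post_init__(self):
        # Same rule as the diff: one 1.0 per adapter when no scales are given.
        if self.lora_requests and not self.lora_scales:
            self.lora_scales = [1.0] * len(self.lora_requests)


style = LoRARequestStub("style", 1, "/path/to/style_lora")
detail = LoRARequestStub("detail", 2, "/path/to/detail_lora")

# Adapters only: scales are filled in at construction time.
defaulted = SamplingParamsSketch(lora_requests=[style, detail])

# Explicit scales are left untouched.
explicit = SamplingParamsSketch(lora_requests=[style, detail], lora_scales=[0.8, 0.5])

# A shorter scales list is non-empty, so the rule does not pad it.
mismatched = SamplingParamsSketch(lora_requests=[style, detail], lora_scales=[0.8])

# Assignment after construction does not rerun __post_init__.
late = SamplingParamsSketch()
late.lora_requests = [style]

print(defaulted.lora_scales)
print(explicit.lora_scales)
print(mismatched.lora_scales)
print(late.lora_scales)
```

Two cases may deserve attention in review. A partial `lora_scales` list and adapters assigned after construction (the pattern `gen_params.lora_requests = lora_reqs` used by the serving code) both bypass the default. The serving paths in this diff always assign `lora_scales` alongside `lora_requests`, so they appear unaffected. Whether downstream code validates that the two lists have equal length is not visible in this diff.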