diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 8f942b0411b..ca7ed79200c 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -41,6 +41,23 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Diffusion Images API LoRA E2E" + timeout_in_minutes: 20 + depends_on: image-build + commands: + - pytest -s -v tests/e2e/online_serving/test_images_generations_lora.py + agents: + queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + environment: + - "HF_HOME=/fsx/hf_cache" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Diffusion Model CPU offloading Test" timeout_in_minutes: 20 depends_on: image-build diff --git a/.buildkite/scripts/simple_test.sh b/.buildkite/scripts/simple_test.sh index 9ac1cb9e434..e76a5c70d14 100755 --- a/.buildkite/scripts/simple_test.sh +++ b/.buildkite/scripts/simple_test.sh @@ -52,6 +52,7 @@ VENV_PYTHON="${VENV_DIR}/bin/python" "${VENV_PYTHON}" -m pytest -v -s tests/entrypoints/ "${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/ +"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/lora/ "${VENV_PYTHON}" -m pytest -v -s tests/model_executor/models/qwen2_5_omni/test_audio_length.py "${VENV_PYTHON}" -m pytest -v -s tests/worker/ "${VENV_PYTHON}" -m pytest -v -s tests/distributed/omni_connectors/test_kv_flow.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e8c3d42d63..9e13098f648 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,10 @@ repos: # only for staged files - repo: https://github.com/rhysd/actionlint - rev: v1.7.9 + # v1.7.8+ sets `go 1.24.0` in go.mod, which older Go toolchains (and most + # current CI images) cannot parse. Pin to v1.7.7 until actionlint fixes the + # go.mod directive. + rev: v1.7.7 hooks: - id: actionlint files: ^\.github/workflows/.*\.ya?ml$ diff --git a/docs/user_guide/examples/offline_inference/lora_inference.md b/docs/user_guide/examples/offline_inference/lora_inference.md new file mode 100644 index 00000000000..dde42655e44 --- /dev/null +++ b/docs/user_guide/examples/offline_inference/lora_inference.md @@ -0,0 +1,107 @@ +# LoRA-Inference + +Source . + +This contains examples for using LoRA (Low-Rank Adaptation) adapters with vLLM-omni diffusion models for offline inference. +The example uses the `stabilityai/stable-diffusion-3.5-medium` as the default model, but you can replace it with other models in vLLM-omni. + +## Overview + +Similar to vLLM, vLLM-omni uses a unified LoRA handling mechanism: + +- **Pre-loaded LoRA**: Loaded at initialization via `--lora-path` (pre-loaded into cache) +- **Per-request LoRA**: Loaded on-demand. In the example, the LoRA is loaded via `--lora-request-path` in each request + +Both approaches use the same underlying mechanism - all LoRA adapters are handled uniformly through `set_active_adapter()`. If no LoRA request is provided in a request, all adapters are deactivated. + +## Usage + +### Pre-loaded LoRA (via --lora-path) + +Load a LoRA adapter at initialization. This adapter is pre-loaded into the cache and can be activated by requests: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_preloaded.png +``` + +**Note**: When using `--lora-path`, the adapter is loaded at init time with a stable ID derived from the adapter path. This example activates it automatically for the request. + +### Per-request LoRA (via --lora-request-path) + +Load a LoRA adapter on-demand for each request: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-request-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_per_request.png +``` + +### No LoRA + +If no LoRA request is provided, we will use the base model without any LoRA adapters: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_no_lora.png +``` + +## Parameters + +### LoRA Parameters + +- `--lora-path`: Path to LoRA adapter folder to pre-load at initialization (loads into cache with a stable ID derived from the path) +- `--lora-request-path`: Path to LoRA adapter folder for per-request loading +- `--lora-request-id`: Integer ID for the LoRA adapter (optional). If not provided and `--lora-request-path` is set, will derive a stable ID from the path. +- `--lora-scale`: Scale factor for LoRA weights (default: 1.0). Higher values increase the influence of the LoRA adapter. + +### Standard Parameters + +- `--prompt`: Text prompt for image generation (required) +- `--seed`: Random seed for reproducibility (default: 42) +- `--height`: Image height in pixels (default: 1024) +- `--width`: Image width in pixels (default: 1024) +- `--num_inference_steps`: Number of denoising steps (default: 50) +- `--output`: Output file path (default: `lora_output.png`) + +## How LoRA Works + +All LoRA adapters are handled uniformly: + +1. **Initialization**: If `--lora-path` is provided, the adapter is loaded into cache with a stable ID derived from the adapter path +2. **Per-request**: If `--lora-request-path` is provided, the adapter is loaded/activated for that request +3. **No LoRA**: If no LoRA request is provided (`req.lora_request` is None), all adapters are deactivated + +The system uses LRU cache management - adapters are cached and evicted when the cache is full (unless pinned). + +## LoRA Adapter Format + +LoRA adapters must be in PEFT (Parameter-Efficient Fine-Tuning) format. A typical LoRA adapter directory structure: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` + +## Example materials + +??? abstract "lora_inference.py" + ``````py + --8<-- "examples/offline_inference/lora_inference/lora_inference.py" + `````` diff --git a/docs/user_guide/examples/online_serving/lora_inference.md b/docs/user_guide/examples/online_serving/lora_inference.md new file mode 100644 index 00000000000..4c8b215d299 --- /dev/null +++ b/docs/user_guide/examples/online_serving/lora_inference.md @@ -0,0 +1,69 @@ +# LoRA-Inference + +Source . + +This example shows how to use **per-request LoRA** with vLLM-Omni diffusion models via the OpenAI-compatible Chat Completions API. + +> Note: The LoRA adapter path must be readable on the **server** machine (usually a local path or a mounted directory). +> Note: This example uses `/v1/chat/completions`. LoRA payloads for other OpenAI endpoints are not implemented here. + +## Start Server + +```bash +# Pick a diffusion model (examples) +# export MODEL=stabilityai/stable-diffusion-3.5-medium +# export MODEL=Qwen/Qwen-Image + +bash run_server.sh +``` + +## Call API (curl) + +```bash +# Required: local LoRA folder on the server +export LORA_PATH=/path/to/lora_adapter + +# Optional +export SERVER=http://localhost:8091 +export PROMPT="A piece of cheesecake" +export LORA_NAME=my_lora +export LORA_SCALE=1.0 +# Optional: if omitted, the server derives a stable id from LORA_PATH. +# export LORA_INT_ID=123 + +bash run_curl_lora_inference.sh +``` + +## Call API (Python) + +```bash +python openai_chat_client.py \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora_adapter \ + --lora-name my_lora \ + --lora-scale 1.0 \ + --output output.png +``` + +## LoRA Format + +LoRA adapters should be in PEFT format, for example: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` + +??? abstract "openai_chat_client.py" + ``````py + --8<-- "examples/online_serving/lora_inference/openai_chat_client.py" + `````` +??? abstract "run_curl_lora_inference.sh" + ``````py + --8<-- "examples/online_serving/lora_inference/run_curl_lora_inference.sh" + `````` +??? abstract "run_server.sh" + ``````py + --8<-- "examples/online_serving/lora_inference/run_server.sh" + `````` diff --git a/examples/offline_inference/lora_inference/README.md b/examples/offline_inference/lora_inference/README.md new file mode 100644 index 00000000000..b0b195f6e6f --- /dev/null +++ b/examples/offline_inference/lora_inference/README.md @@ -0,0 +1,98 @@ +# LoRA Inference Examples + +This directory contains examples for using LoRA (Low-Rank Adaptation) adapters with vLLM-omni diffusion models for offline inference. +The example uses the `stabilityai/stable-diffusion-3.5-medium` as the default model, but you can replace it with other models in vLLM-omni. + +## Overview + +Similar to vLLM, vLLM-omni uses a unified LoRA handling mechanism: + +- **Pre-loaded LoRA**: Loaded at initialization via `--lora-path` (pre-loaded into cache) +- **Per-request LoRA**: Loaded on-demand. In the example, the LoRA is loaded via `--lora-request-path` in each request + +Both approaches use the same underlying mechanism - all LoRA adapters are handled uniformly through `set_active_adapter()`. If no LoRA request is provided in a request, all adapters are deactivated. + +## Usage + +### Pre-loaded LoRA (via --lora-path) + +Load a LoRA adapter at initialization. This adapter is pre-loaded into the cache and can be activated by requests: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_preloaded.png +``` + +**Note**: When using `--lora-path`, the adapter is loaded at init time with a stable ID derived from the adapter path. This example activates it automatically for the request. + +### Per-request LoRA (via --lora-request-path) + +Load a LoRA adapter on-demand for each request: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --lora-request-path /path/to/lora/ \ + --lora-scale 1.0 \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_per_request.png +``` + +### No LoRA + +If no LoRA request is provided, we will use the base model without any LoRA adapters: + +```bash +python -m examples.offline_inference.lora_inference.lora_inference \ + --prompt "A piece of cheesecake" \ + --num_inference_steps 50 \ + --height 1024 \ + --width 1024 \ + --output output_no_lora.png +``` + +## Parameters + +### LoRA Parameters + +- `--lora-path`: Path to LoRA adapter folder to pre-load at initialization (loads into cache with a stable ID derived from the path) +- `--lora-request-path`: Path to LoRA adapter folder for per-request loading +- `--lora-request-id`: Integer ID for the LoRA adapter (optional). If not provided and `--lora-request-path` is set, will derive a stable ID from the path. +- `--lora-scale`: Scale factor for LoRA weights (default: 1.0). Higher values increase the influence of the LoRA adapter. + +### Standard Parameters + +- `--prompt`: Text prompt for image generation (required) +- `--seed`: Random seed for reproducibility (default: 42) +- `--height`: Image height in pixels (default: 1024) +- `--width`: Image width in pixels (default: 1024) +- `--num_inference_steps`: Number of denoising steps (default: 50) +- `--output`: Output file path (default: `lora_output.png`) + +## How LoRA Works + +All LoRA adapters are handled uniformly: + +1. **Initialization**: If `--lora-path` is provided, the adapter is loaded into cache with a stable ID derived from the adapter path +2. **Per-request**: If `--lora-request-path` is provided, the adapter is loaded/activated for that request +3. **No LoRA**: If no LoRA request is provided (`req.lora_request` is None), all adapters are deactivated + +The system uses LRU cache management - adapters are cached and evicted when the cache is full (unless pinned). + +## LoRA Adapter Format + +LoRA adapters must be in PEFT (Parameter-Efficient Fine-Tuning) format. A typical LoRA adapter directory structure: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` diff --git a/examples/offline_inference/lora_inference/lora_inference.py b/examples/offline_inference/lora_inference/lora_inference.py new file mode 100644 index 00000000000..17e9d6196dd --- /dev/null +++ b/examples/offline_inference/lora_inference/lora_inference.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +from pathlib import Path + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Generate images with LoRA adapters.") + parser.add_argument("--model", default="stabilityai/stable-diffusion-3.5-medium", help="Model name or path.") + parser.add_argument("--prompt", required=True, help="Text prompt for image generation.") + parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.") + parser.add_argument("--height", type=int, default=1024, help="Height of generated image.") + parser.add_argument("--width", type=int, default=1024, help="Width of generated image.") + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Number of denoising steps for the diffusion sampler.", + ) + parser.add_argument( + "--output", + type=str, + default="lora_output.png", + help="Path to save the generated image (PNG).", + ) + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to LoRA adapter folder to pre-load at initialization (PEFT format). " + "Note: pre-loading populates the cache; you still need to pass a lora_request to activate it.", + ) + parser.add_argument( + "--lora-request-path", + type=str, + default=None, + help="Path to LoRA adapter folder for per-request activation (dynamic LoRA). " + "If --lora-request-id is not provided, a stable ID will be derived from this path.", + ) + parser.add_argument( + "--lora-request-id", + type=int, + default=None, + help="Integer ID for the LoRA adapter (for dynamic LoRA). " + "If not provided and --lora-request-path is set, will derive a stable ID from the path.", + ) + parser.add_argument( + "--lora-scale", + type=float, + default=1.0, + help="Scale factor for LoRA weights (default: 1.0).", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + model = args.model + + omni_kwargs = {} + + if args.lora_path: + omni_kwargs["lora_path"] = args.lora_path + print(f"Using static LoRA from: {args.lora_path}") + + omni = Omni(model=model, **omni_kwargs) + + lora_request = None + if args.lora_request_path: + if args.lora_request_id is None: + lora_request_id = stable_lora_int_id(args.lora_request_path) + else: + lora_request_id = args.lora_request_id + + lora_name = Path(args.lora_request_path).stem + lora_request = LoRARequest( + lora_name=lora_name, + lora_int_id=lora_request_id, + lora_path=args.lora_request_path, + ) + print(f"Using per-request LoRA: name={lora_name}, id={lora_request_id}, scale={args.lora_scale}") + elif args.lora_path: + # pre-loaded LoRA + lora_request_id = stable_lora_int_id(args.lora_path) + lora_request = LoRARequest( + lora_name="preloaded", + lora_int_id=lora_request_id, + lora_path=args.lora_path, + ) + print(f"Activating pre-loaded LoRA: id={lora_request_id}, scale={args.lora_scale}") + + gen_kwargs = { + "prompt": args.prompt, + "height": args.height, + "width": args.width, + "num_inference_steps": args.num_inference_steps, + } + + if lora_request: + gen_kwargs["lora_request"] = lora_request + gen_kwargs["lora_scale"] = args.lora_scale + + outputs = omni.generate(**gen_kwargs) + + if not outputs or len(outputs) == 0: + raise ValueError("No output generated from omni.generate()") + + if isinstance(outputs, list): + first_output = outputs[0] + else: + first_output = outputs + + images = None + if hasattr(first_output, "images") and first_output.images: + images = first_output.images + elif hasattr(first_output, "request_output") and first_output.request_output: + req_out = first_output.request_output + if isinstance(req_out, list) and len(req_out) > 0: + req_out = req_out[0] + if hasattr(req_out, "images") and req_out.images: + images = req_out.images + + if not images: + raise ValueError("No images found in request_output") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + suffix = output_path.suffix or ".png" + stem = output_path.stem or "lora_output" + if len(images) <= 1: + images[0].save(output_path) + print(f"Saved generated image to {output_path}") + else: + for idx, img in enumerate(images): + save_path = output_path.parent / f"{stem}_{idx}{suffix}" + img.save(save_path) + print(f"Saved generated image to {save_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lora_inference/README.md b/examples/online_serving/lora_inference/README.md new file mode 100644 index 00000000000..16ce55313dd --- /dev/null +++ b/examples/online_serving/lora_inference/README.md @@ -0,0 +1,54 @@ +# Online LoRA Inference (Diffusion) + +This example shows how to use **per-request LoRA** with vLLM-Omni diffusion models via the OpenAI-compatible Chat Completions API. + +> Note: The LoRA adapter path must be readable on the **server** machine (usually a local path or a mounted directory). +> Note: This example uses `/v1/chat/completions`. LoRA payloads for other OpenAI endpoints are not implemented here. + +## Start Server + +```bash +# Pick a diffusion model (examples) +# export MODEL=stabilityai/stable-diffusion-3.5-medium +# export MODEL=Qwen/Qwen-Image + +bash run_server.sh +``` + +## Call API (curl) + +```bash +# Required: local LoRA folder on the server +export LORA_PATH=/path/to/lora_adapter + +# Optional +export SERVER=http://localhost:8091 +export PROMPT="A piece of cheesecake" +export LORA_NAME=my_lora +export LORA_SCALE=1.0 +# Optional: if omitted, the server derives a stable id from LORA_PATH. +# export LORA_INT_ID=123 + +bash run_curl_lora_inference.sh +``` + +## Call API (Python) + +```bash +python openai_chat_client.py \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora_adapter \ + --lora-name my_lora \ + --lora-scale 1.0 \ + --output output.png +``` + +## LoRA Format + +LoRA adapters should be in PEFT format, for example: + +``` +lora_adapter/ +├── adapter_config.json +└── adapter_model.safetensors +``` diff --git a/examples/online_serving/lora_inference/openai_chat_client.py b/examples/online_serving/lora_inference/openai_chat_client.py new file mode 100644 index 00000000000..e24d2fdf65b --- /dev/null +++ b/examples/online_serving/lora_inference/openai_chat_client.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +""" +OpenAI-compatible chat client for diffusion LoRA inference. + +Example: + python openai_chat_client.py \ + --server http://localhost:8091 \ + --prompt "A piece of cheesecake" \ + --lora-path /path/to/lora_adapter \ + --lora-name my_lora \ + --lora-scale 1.0 \ + --output output.png +""" + +import argparse +import base64 +from pathlib import Path + +import requests + + +def generate_image( + *, + prompt: str, + server_url: str, + height: int | None, + width: int | None, + num_inference_steps: int | None, + seed: int | None, + lora_name: str | None, + lora_path: str | None, + lora_scale: float | None, + lora_int_id: int | None, +) -> bytes | None: + messages = [{"role": "user", "content": prompt}] + + extra_body: dict = {} + if height is not None: + extra_body["height"] = height + if width is not None: + extra_body["width"] = width + if num_inference_steps is not None: + extra_body["num_inference_steps"] = num_inference_steps + if seed is not None: + extra_body["seed"] = seed + + if lora_path: + lora_body: dict = { + "local_path": lora_path, + "name": lora_name or Path(lora_path).stem, + } + if lora_scale is not None: + lora_body["scale"] = float(lora_scale) + if lora_int_id is not None: + lora_body["int_id"] = int(lora_int_id) + extra_body["lora"] = lora_body + + payload = {"messages": messages} + if extra_body: + payload["extra_body"] = extra_body + + response = requests.post( + f"{server_url}/v1/chat/completions", + headers={"Content-Type": "application/json"}, + json=payload, + timeout=300, + ) + response.raise_for_status() + data = response.json() + + content = data["choices"][0]["message"]["content"] + if isinstance(content, list) and content: + image_url = content[0].get("image_url", {}).get("url", "") + if image_url.startswith("data:image"): + _, b64_data = image_url.split(",", 1) + return base64.b64decode(b64_data) + + raise RuntimeError(f"Unexpected response format: {content!r}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Diffusion LoRA OpenAI chat client") + parser.add_argument("--server", default="http://localhost:8091", help="Server URL") + parser.add_argument("--prompt", default="A piece of cheesecake", help="Text prompt") + parser.add_argument("--output", default="lora_online_output.png", help="Output image path") + + parser.add_argument("--height", type=int, default=1024, help="Image height") + parser.add_argument("--width", type=int, default=1024, help="Image width") + parser.add_argument("--steps", type=int, default=50, help="num_inference_steps") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + parser.add_argument("--lora-path", default=None, help="Server-local LoRA adapter folder (PEFT format)") + parser.add_argument("--lora-name", default=None, help="LoRA name (optional)") + parser.add_argument("--lora-scale", type=float, default=1.0, help="LoRA scale") + parser.add_argument( + "--lora-int-id", + type=int, + default=None, + help="LoRA integer id (cache key). If omitted, the server derives a stable id from lora_path.", + ) + + args = parser.parse_args() + + image_bytes = generate_image( + prompt=args.prompt, + server_url=args.server, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + seed=args.seed, + lora_name=args.lora_name, + lora_path=args.lora_path, + lora_scale=args.lora_scale if args.lora_path else None, + lora_int_id=args.lora_int_id if args.lora_path else None, + ) + + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_bytes(image_bytes) + print(f"Saved: {out_path} ({len(image_bytes) / 1024:.1f} KiB)") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/lora_inference/run_curl_lora_inference.sh b/examples/online_serving/lora_inference/run_curl_lora_inference.sh new file mode 100755 index 00000000000..14a074fbf87 --- /dev/null +++ b/examples/online_serving/lora_inference/run_curl_lora_inference.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Online diffusion LoRA inference via OpenAI-compatible chat API. + +SERVER="${SERVER:-http://localhost:8091}" +PROMPT="${PROMPT:-A piece of cheesecake}" + +LORA_PATH="${LORA_PATH:-}" +LORA_NAME="${LORA_NAME:-lora}" +LORA_SCALE="${LORA_SCALE:-1.0}" +LORA_INT_ID="${LORA_INT_ID:-}" + +HEIGHT="${HEIGHT:-1024}" +WIDTH="${WIDTH:-1024}" +NUM_INFERENCE_STEPS="${NUM_INFERENCE_STEPS:-50}" +SEED="${SEED:-42}" + +CURRENT_TIME=$(date +%Y%m%d%H%M%S) +OUTPUT="${OUTPUT:-lora_online_output_${CURRENT_TIME}.png}" + +if [ -z "$LORA_PATH" ]; then + echo "ERROR: LORA_PATH is required (must be a server-local path)." + exit 1 +fi + +echo "Generating image with LoRA..." +echo "Server: $SERVER" +echo "Prompt: $PROMPT" +echo "LoRA: name=$LORA_NAME id=${LORA_INT_ID:-auto} scale=$LORA_SCALE path=$LORA_PATH" +echo "Output: $OUTPUT" + +LORA_INT_ID_FIELD="" +if [ -n "$LORA_INT_ID" ]; then + LORA_INT_ID_FIELD=", \"int_id\": $LORA_INT_ID" +fi + +curl -s "$SERVER/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{ + \"messages\": [ + {\"role\": \"user\", \"content\": \"$PROMPT\"} + ], + \"extra_body\": { + \"height\": $HEIGHT, + \"width\": $WIDTH, + \"num_inference_steps\": $NUM_INFERENCE_STEPS, + \"seed\": $SEED, + \"lora\": { + \"name\": \"$LORA_NAME\", + \"local_path\": \"$LORA_PATH\", + \"scale\": $LORA_SCALE$LORA_INT_ID_FIELD + } + } + }" | jq -r '.choices[0].message.content[0].image_url.url' | sed 's/^data:image[^,]*,\s*//' | base64 -d > "$OUTPUT" + +if [ -f "$OUTPUT" ]; then + echo "Image saved to: $OUTPUT" + echo "Size: $(du -h "$OUTPUT" | cut -f1)" +else + echo "Failed to generate image" + exit 1 +fi diff --git a/examples/online_serving/lora_inference/run_server.sh b/examples/online_serving/lora_inference/run_server.sh new file mode 100755 index 00000000000..3233dd77397 --- /dev/null +++ b/examples/online_serving/lora_inference/run_server.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Online diffusion serving with vLLM-Omni (OpenAI-compatible API). + +MODEL="${MODEL:-stabilityai/stable-diffusion-3.5-medium}" +PORT="${PORT:-8091}" + +echo "Starting vLLM-Omni diffusion server..." +echo "Model: $MODEL" +echo "Port: $PORT" + +if [ -z "${VLLM_BIN:-}" ]; then + if command -v vllm-omni >/dev/null 2>&1; then + VLLM_BIN="vllm-omni" + else + VLLM_BIN="vllm" + fi +fi + +"$VLLM_BIN" serve "$MODEL" --omni \ + --port "$PORT" diff --git a/tests/diffusion/attention/test_flash_attn.py b/tests/diffusion/attention/test_flash_attn.py deleted file mode 100644 index 3f3862405ed..00000000000 --- a/tests/diffusion/attention/test_flash_attn.py +++ /dev/null @@ -1,290 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -""" -Test script for FlashAttention backend with padding handling. - -This script tests two main scenarios: -1. Case 1: Comparing padded vs unpadded inputs for batch_size=1 -2. Case 2: Comparing FlashAttention and SDPA backends for batch_size=2 with padding -""" - -import pytest -import torch - -from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata -from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl -from vllm_omni.diffusion.attention.backends.sdpa import SDPAImpl - - -def create_attention_mask(batch_size: int, seq_len: int, valid_len: int, device: torch.device) -> torch.Tensor: - """ - Create attention mask where first valid_len tokens are valid (1) and rest are padding (0). - - Args: - batch_size: Batch size - seq_len: Total sequence length (including padding) - valid_len: Number of valid (non-padded) tokens - - Returns: - Attention mask of shape (batch_size, seq_len) - """ - mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) - mask[:, :valid_len] = True - return mask - - -def pad_tensor(tensor: torch.Tensor, target_seq_len: int, pad_value: float = 0.0) -> torch.Tensor: - """ - Pad tensor along sequence dimension (dim=1). - - Args: - tensor: Input tensor of shape (batch_size, seq_len, num_heads, head_dim) - target_seq_len: Target sequence length after padding - pad_value: Value to use for padding - - Returns: - Padded tensor of shape (batch_size, target_seq_len, num_heads, head_dim) - """ - batch_size, seq_len, num_heads, head_dim = tensor.shape - if target_seq_len <= seq_len: - return tensor - - padding = torch.full( - (batch_size, target_seq_len - seq_len, num_heads, head_dim), pad_value, dtype=tensor.dtype, device=tensor.device - ) - return torch.cat([tensor, padding], dim=1) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA") -def test_padding_equivalence(): - """ - Case 1: Test that padded and unpadded inputs produce similar outputs. - - - Input A: batch_size=1, hidden_states (1, 48), encoder_hidden_states (1, 16) - Concatenated length: 64, NO attention_mask - - Input B: Same data but padded: hidden_states (1, 58), encoder_hidden_states (1, 26) - Concatenated length: 84, WITH attention_mask - - Expected: Output A and Output B should be very close. - """ - device = torch.device("cuda") - dtype = torch.bfloat16 - - # Configuration - batch_size = 1 - hidden_seq_len = 48 - encoder_seq_len = 16 - pad_length = 10 - num_heads = 8 - head_dim = 64 - - # Initialize FlashAttention - fa_impl = FlashAttentionImpl( - num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False - ) - - # Create base tensors with random values (same for both A and B) - torch.manual_seed(42) - hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype) - encoder_hidden_states_base = torch.randn( - batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype - ) - - # ========== Input A: Unpadded, no attention mask ========== - query_a = torch.cat([hidden_states_base, encoder_hidden_states_base], dim=1) - key_a = query_a.clone() - value_a = query_a.clone() - - attn_metadata_a = AttentionMetadata(attn_mask=None) - - output_a = fa_impl.forward(query=query_a, key=key_a, value=value_a, attn_metadata=attn_metadata_a) - - # ========== Input B: Padded with attention mask ========== - hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length) - encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length) - - query_b = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1) - key_b = query_b.clone() - value_b = query_b.clone() - - # Create attention mask - attn_mask_b = torch.cat( - [ - create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device), - create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device), - ], - dim=1, - ) - - attn_metadata_b = AttentionMetadata(attn_mask=attn_mask_b) - - output_b = fa_impl.forward(query=query_b, key=key_b, value=value_b, attn_metadata=attn_metadata_b) - - # Extract non-padded portion from output_b - output_b_unpadded = torch.cat( - [ - output_b[:, :hidden_seq_len, :, :], - output_b[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :], - ], - dim=1, - ) - - # Compare outputs - max_diff = torch.max(torch.abs(output_a - output_b_unpadded)).item() - mean_diff = torch.mean(torch.abs(output_a - output_b_unpadded)).item() - - print("\n=== Case 1: Padding Equivalence Test ===") - print(f"Output A shape: {output_a.shape}") - print(f"Output B shape: {output_b.shape}") - print(f"Output B unpadded shape: {output_b_unpadded.shape}") - print(f"Max absolute difference: {max_diff:.6f}") - print(f"Mean absolute difference: {mean_diff:.6f}") - - # Assert that outputs are close - # Using higher tolerance for bfloat16 - assert max_diff < 0.1, f"Max difference {max_diff} exceeds threshold 0.1" - assert mean_diff < 0.01, f"Mean difference {mean_diff} exceeds threshold 0.01" - - print("✓ Case 1 PASSED: Padded and unpadded outputs are very close!") - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA") -def test_fa_vs_sdpa(): - """ - Case 2: Compare FlashAttention and SDPA backends with padding. - - - batch_size=2 - - hidden_states: (2, 48) padded to (2, 58) - - encoder_hidden_states: (2, 16) padded to (2, 26) - - Concatenated length: 84 - - Compare FA and SDPA outputs - - Expected: FA and SDPA outputs should be very close. - """ - device = torch.device("cuda") - dtype = torch.bfloat16 - - # Configuration - batch_size = 2 - hidden_seq_len = 48 - encoder_seq_len = 16 - pad_length = 10 - num_heads = 8 - head_dim = 64 - - # Initialize both backends - fa_impl = FlashAttentionImpl( - num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False - ) - - sdpa_impl = SDPAImpl(num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False) - - # Create base tensors - torch.manual_seed(123) - hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype) - encoder_hidden_states_base = torch.randn( - batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype - ) - - # Pad tensors - hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length) - encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length) - - # Concatenate - query = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1) - key = query.clone() - value = query.clone() - - # Create attention mask - attn_mask = torch.cat( - [ - create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device), - create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device), - ], - dim=1, - ) - - attn_metadata = AttentionMetadata(attn_mask=attn_mask) - - # Run FlashAttention - output_fa = fa_impl.forward(query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata) - - # Run SDPA - # SDPA expects 4D attention mask: (batch_size, 1, seq_len, seq_len) or (batch_size, seq_len) - # For causal=False, we need to convert 2D mask to 4D - if attn_mask is not None: - # Expand mask for SDPA: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) - attn_mask_4d = attn_mask.unsqueeze(1).unsqueeze(2) - # Convert bool to float: True -> 0.0, False -> -inf - attn_mask_float = torch.zeros_like(attn_mask_4d, dtype=dtype) - attn_mask_float.masked_fill_(~attn_mask_4d, float("-inf")) - attn_metadata_sdpa = AttentionMetadata(attn_mask=attn_mask_float) - else: - attn_metadata_sdpa = AttentionMetadata(attn_mask=None) - - output_sdpa = sdpa_impl.forward( - query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata_sdpa - ) - - # Compare outputs (only compare valid regions) - output_fa_valid = torch.cat( - [ - output_fa[:, :hidden_seq_len, :, :], - output_fa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :], - ], - dim=1, - ) - output_sdpa_valid = torch.cat( - [ - output_sdpa[:, :hidden_seq_len, :, :], - output_sdpa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :], - ], - dim=1, - ) - - max_diff = torch.max(torch.abs(output_fa_valid - output_sdpa_valid)).item() - mean_diff = torch.mean(torch.abs(output_fa_valid - output_sdpa_valid)).item() - - print("\n=== Case 2: FA vs SDPA Comparison ===") - print(f"Batch size: {batch_size}") - print(f"FA output shape: {output_fa.shape}") - print(f"SDPA output shape: {output_sdpa.shape}") - print(f"Max absolute difference (valid region): {max_diff:.6f}") - print(f"Mean absolute difference (valid region): {mean_diff:.6f}") - - # Assert that outputs are close - # Using higher tolerance for bfloat16 and different implementations - assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold 0.01" - assert mean_diff < 0.001, f"Mean difference {mean_diff} exceeds threshold 0.001" - - print("✓ Case 2 PASSED: FA and SDPA outputs are very close!") - - -if __name__ == "__main__": - print("Running FlashAttention Padding Tests...") - print("=" * 60) - - # Try to run CUDA tests - if torch.cuda.is_available(): - try: - print("\n[Running Case 1: Padding Equivalence for FA]") - test_padding_equivalence() - except Exception as e: - print(f"✗ Case 1 failed: {e}") - import traceback - - traceback.print_exc() - - try: - print("\n[Running Case 2: FA vs SDPA]") - test_fa_vs_sdpa() - except Exception as e: - print(f"✗ Case 2 failed: {e}") - import traceback - - traceback.print_exc() - else: - raise RuntimeError("CUDA is not available") - print("\n" + "=" * 60) - print("Test suite completed!") diff --git a/tests/diffusion/lora/test_base_linear.py b/tests/diffusion/lora/test_base_linear.py new file mode 100644 index 00000000000..42bdf6526a5 --- /dev/null +++ b/tests/diffusion/lora/test_base_linear.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from vllm_omni.diffusion.lora.layers.base_linear import DiffusionBaseLinearLayerWithLoRA + + +@dataclass +class _DummyLoRAConfig: + fully_sharded_loras: bool = False + + +class _DummyQuantMethod: + def __init__(self, weight: torch.Tensor): + self._weight = weight + + def apply(self, _base_layer, x: torch.Tensor, bias: torch.Tensor | None): + y = x @ self._weight.t() + if bias is not None: + y = y + bias + return y + + +def test_diffusion_base_linear_apply_multi_slice(): + # Build a fake diffusion LoRA layer with 2 slices and rank=2. + layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA) + layer.tp_size = 1 + layer.lora_config = _DummyLoRAConfig() + + in_dim = 3 + out_slices = (2, 1) + rank = 2 + + # Base weight: identity-ish mapping to make base output easy to reason about. + base_weight = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + layer.base_layer = type("Base", (), {})() + layer.base_layer.quant_method = _DummyQuantMethod(base_weight) + + # Allocate stacked weights: (max_loras=1, 1, rank, in_dim) and (1, 1, out_slice, rank) + a0 = torch.zeros((1, 1, rank, in_dim)) + b0 = torch.zeros((1, 1, out_slices[0], rank)) + a1 = torch.zeros((1, 1, rank, in_dim)) + b1 = torch.zeros((1, 1, out_slices[1], rank)) + + # Slice 0: delta0 = (x @ A0.T) @ B0.T + A0 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) # (2, 3) + B0 = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # (2, 2) + a0[0, 0, :, :] = A0 + b0[0, 0, :, :] = B0 + + # Slice 1: delta1 = (x @ A1.T) @ B1.T + A1 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) # (2, 3) + B1 = torch.tensor([[2.0, 0.0]]) # (1, 2) + a1[0, 0, :, :] = A1 + b1[0, 0, :, :] = B1 + + layer.lora_a_stacked = (a0, a1) + layer.lora_b_stacked = (b0, b1) + layer.output_slices = out_slices + + x = torch.tensor([[1.0, 2.0, 3.0]]) + out = layer.apply(x) + + # Base output is identity: [1,2,3] + base_out = x @ base_weight.t() + # delta0: + # (x @ A0.T) = [1,2] + # [1,2] @ B0.T = [1,2] + delta0 = torch.tensor([[1.0, 2.0]]) + # delta1: + # (x @ A1.T) = [3,1] + # [3,1] @ B1.T = [6] + delta1 = torch.tensor([[6.0]]) + expected = torch.cat([base_out[:, :2] + delta0, base_out[:, 2:3] + delta1], dim=-1) + assert torch.allclose(out, expected) + + +def test_diffusion_base_linear_reset_lora_disables_fast_path(monkeypatch): + # Verify that after reset_lora(), apply() skips LoRA matmuls even if the + # LoRA tensors are still allocated and non-empty. + from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA + + layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA) + layer.tp_size = 1 + layer.lora_config = _DummyLoRAConfig() + + in_dim = 2 + out_dim = 2 + rank = 1 + + base_weight = torch.eye(in_dim) + layer.base_layer = type("Base", (), {})() + layer.base_layer.quant_method = _DummyQuantMethod(base_weight) + + a = torch.ones((1, 1, rank, in_dim)) + b = torch.tensor([[[[1.0], [2.0]]]]) # (1,1,out_dim,rank) + + layer.lora_a_stacked = (a,) + layer.lora_b_stacked = (b,) + layer.output_slices = (out_dim,) + layer._diffusion_lora_active_slices = (True,) + + x = torch.tensor([[1.0, 2.0]]) + out_active = layer.apply(x) + assert torch.allclose(out_active, torch.tensor([[4.0, 8.0]])) + + monkeypatch.setattr(BaseLinearLayerWithLoRA, "reset_lora", lambda self, index: None) + layer.reset_lora(0) + + assert layer._diffusion_lora_active_slices == (False,) + out_inactive = layer.apply(x) + assert torch.allclose(out_inactive, x) + + +def test_diffusion_base_linear_apply_respects_inactive_slices(): + # Build a fake diffusion LoRA layer with 2 slices and rank=2. + layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA) + layer.tp_size = 1 + layer.lora_config = _DummyLoRAConfig() + + in_dim = 3 + out_slices = (2, 1) + rank = 2 + + base_weight = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + layer.base_layer = type("Base", (), {})() + layer.base_layer.quant_method = _DummyQuantMethod(base_weight) + + a0 = torch.zeros((1, 1, rank, in_dim)) + b0 = torch.zeros((1, 1, out_slices[0], rank)) + a1 = torch.zeros((1, 1, rank, in_dim)) + b1 = torch.zeros((1, 1, out_slices[1], rank)) + + A0 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) # (2, 3) + B0 = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # (2, 2) + a0[0, 0, :, :] = A0 + b0[0, 0, :, :] = B0 + + A1 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) # (2, 3) + B1 = torch.tensor([[2.0, 0.0]]) # (1, 2) + a1[0, 0, :, :] = A1 + b1[0, 0, :, :] = B1 + + layer.lora_a_stacked = (a0, a1) + layer.lora_b_stacked = (b0, b1) + layer.output_slices = out_slices + layer._diffusion_lora_active_slices = (True, False) + + x = torch.tensor([[1.0, 2.0, 3.0]]) + out = layer.apply(x) + + # Only the first slice should be adapted. + expected = torch.tensor([[2.0, 4.0, 3.0]]) + assert torch.allclose(out, expected) diff --git a/tests/diffusion/lora/test_lora_manager.py b/tests/diffusion/lora/test_lora_manager.py new file mode 100644 index 00000000000..84fafe3bc9e --- /dev/null +++ b/tests/diffusion/lora/test_lora_manager.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch +from vllm.lora.lora_weights import LoRALayerWeights +from vllm.lora.utils import get_supported_lora_modules +from vllm.model_executor.layers.linear import LinearBase + +from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager +from vllm_omni.lora.request import LoRARequest + + +class _DummyLoRALayer: + def __init__(self, n_slices: int, output_slices: tuple[int, ...]): + self.n_slices = n_slices + self.output_slices = output_slices + self.set_calls: list[ + tuple[list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor] + ] = [] + self.reset_calls: int = 0 + + def set_lora(self, index: int, lora_a, lora_b): + assert index == 0 + self.set_calls.append((lora_a, lora_b)) + + def reset_lora(self, index: int): + assert index == 0 + self.reset_calls += 1 + + +class _FakeLinearBase(LinearBase): + def __init__(self): + torch.nn.Module.__init__(self) + + +def test_lora_manager_supported_modules_are_stable_with_wrapped_layers(monkeypatch): + # Simulate a pipeline that already contains LoRA wrappers where the original + # LinearBase is nested under ".base_layer". + import vllm_omni.diffusion.lora.manager as manager_mod + + class _DummyBaseLayerWithLoRA(torch.nn.Module): + def __init__(self, base_layer: torch.nn.Module): + super().__init__() + self.base_layer = base_layer + + monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA) + + pipeline = torch.nn.Module() + pipeline.transformer = torch.nn.Module() + pipeline.transformer.foo = _DummyBaseLayerWithLoRA(_FakeLinearBase()) + + # vLLM helper would see only the nested LinearBase and yield "base_layer". + assert get_supported_lora_modules(pipeline) == ["base_layer"] + + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + assert "foo" in manager._supported_lora_modules + assert "base_layer" not in manager._supported_lora_modules + + +def test_lora_manager_replace_layers_does_not_rewrap_base_layer(monkeypatch): + import vllm_omni.diffusion.lora.manager as manager_mod + + class _DummyBaseLayerWithLoRA(torch.nn.Module): + def __init__(self, base_layer: torch.nn.Module): + super().__init__() + self.base_layer = base_layer + + monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA) + + def _fake_from_layer_diffusion(*, layer: torch.nn.Module, **_kwargs): + if isinstance(layer, _FakeLinearBase): + return _DummyBaseLayerWithLoRA(layer) + return layer + + replace_calls: list[str] = [] + + def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule: torch.nn.Module): + replace_calls.append(module_name) + setattr(root, module_name, submodule) + + monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion) + monkeypatch.setattr(manager_mod, "replace_submodule", _fake_replace_submodule) + + pipeline = torch.nn.Module() + pipeline.transformer = torch.nn.Module() + pipeline.transformer.foo = _FakeLinearBase() + + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + peft_helper = type("_PH", (), {"r": 1})() + + manager._replace_layers_with_lora(peft_helper) + manager._replace_layers_with_lora(peft_helper) + + # Only the top-level layer should have been replaced; nested ".base_layer" + # must be skipped to avoid nesting LoRA wrappers. + assert replace_calls == ["foo"] + + +def test_lora_manager_replaces_packed_layer_when_targeting_sublayers(monkeypatch): + import vllm_omni.diffusion.lora.manager as manager_mod + + class _DummyBaseLayerWithLoRA(torch.nn.Module): + def __init__(self, base_layer: torch.nn.Module): + super().__init__() + self.base_layer = base_layer + + monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA) + + def _fake_from_layer_diffusion(*, layer: torch.nn.Module, **_kwargs): + return _DummyBaseLayerWithLoRA(layer) + + replace_calls: list[str] = [] + + def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule: torch.nn.Module): + replace_calls.append(module_name) + setattr(root, module_name, submodule) + + monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion) + monkeypatch.setattr(manager_mod, "replace_submodule", _fake_replace_submodule) + + pipeline = torch.nn.Module() + pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]} + pipeline.transformer = torch.nn.Module() + pipeline.transformer.to_qkv = _FakeLinearBase() + + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + # Treat the dummy layer as a packed 3-slice projection so the manager uses + # `packed_modules_mapping` to decide replacement based on target_modules. + monkeypatch.setattr(manager, "_get_packed_modules_list", lambda _module: ["q", "k", "v"]) + + peft_helper = type("_PH", (), {"r": 1, "target_modules": ["to_q"]})() + manager._replace_layers_with_lora(peft_helper) + + assert replace_calls == ["to_qkv"] + + +def test_lora_manager_activates_fused_lora_on_packed_layer(): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + packed_layer = _DummyLoRALayer(n_slices=3, output_slices=(2, 1, 1)) + manager._lora_modules = {"transformer.blocks.0.attn.to_qkv": packed_layer} + + rank = 2 + A = torch.ones((rank, 4)) + B = torch.arange(0, sum(packed_layer.output_slices) * rank, dtype=torch.bfloat16).view(-1, rank) + lora = LoRALayerWeights( + module_name="transformer.blocks.0.attn.to_qkv", + rank=rank, + lora_alpha=rank, + lora_a=A, + lora_b=B, + ) + manager._registered_adapters = { + 7: type( + "LM", + (), + { + "id": 7, + "loras": {"transformer.blocks.0.attn.to_qkv": lora}, + "get_lora": lambda self, k: self.loras.get(k), + }, + )() + } + manager._adapter_scales = {7: 0.5} + + manager._activate_adapter(7) + + assert packed_layer.reset_calls == 0 + assert len(packed_layer.set_calls) == 1 + lora_a_list, lora_b_list = packed_layer.set_calls[0] + assert isinstance(lora_a_list, list) + assert isinstance(lora_b_list, list) + assert len(lora_a_list) == 3 + assert len(lora_b_list) == 3 + assert all(torch.allclose(a, A) for a in lora_a_list) + # B should be split into 3 slices and scaled. + b0, b1, b2 = lora_b_list + assert b0.shape[0] == 2 and b1.shape[0] == 1 and b2.shape[0] == 1 + assert torch.allclose(torch.cat([b0, b1, b2], dim=0), B * 0.5) + + +def test_lora_manager_activates_packed_lora_from_sublayers(): + pipeline = torch.nn.Module() + pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]} + manager = DiffusionLoRAManager( + pipeline=pipeline, + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=1, + ) + + packed_layer = _DummyLoRALayer(n_slices=3, output_slices=(2, 1, 1)) + manager._lora_modules = {"transformer.blocks.0.attn.to_qkv": packed_layer} + + rank = 2 + loras: dict[str, LoRALayerWeights] = {} + for name, out_dim in zip(["to_q", "to_k", "to_v"], [2, 1, 1]): + loras[f"transformer.blocks.0.attn.{name}"] = LoRALayerWeights( + module_name=f"transformer.blocks.0.attn.{name}", + rank=rank, + lora_alpha=rank, + lora_a=torch.ones((rank, 4)) * (1 if name == "to_q" else 2), + lora_b=torch.ones((out_dim, rank)) * (3 if name == "to_q" else 4), + ) + + manager._registered_adapters = { + 1: type("LM", (), {"id": 1, "loras": loras, "get_lora": lambda self, k: self.loras.get(k)})() + } + manager._adapter_scales = {1: 2.0} + + manager._activate_adapter(1) + + assert packed_layer.reset_calls == 0 + assert len(packed_layer.set_calls) == 1 + lora_a_list, lora_b_list = packed_layer.set_calls[0] + assert isinstance(lora_a_list, list) + assert isinstance(lora_b_list, list) + assert len(lora_a_list) == 3 + assert len(lora_b_list) == 3 + # Scale should apply to B only. + assert torch.allclose(lora_b_list[0], torch.ones((2, rank)) * 3 * 2.0) + assert torch.allclose(lora_b_list[1], torch.ones((1, rank)) * 4 * 2.0) + assert torch.allclose(lora_b_list[2], torch.ones((1, rank)) * 4 * 2.0) + + +def _dummy_lora_request(adapter_id: int) -> LoRARequest: + return LoRARequest( + lora_name=f"adapter_{adapter_id}", + lora_int_id=adapter_id, + lora_path=f"/tmp/adapter_{adapter_id}", + ) + + +def test_lora_manager_evicts_lru_adapter_when_cache_full(monkeypatch): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=2, + ) + + def _fake_load(_req: LoRARequest): + lora_model = type("LM", (), {"id": _req.lora_int_id})() + peft_helper = type("PH", (), {})() + return lora_model, peft_helper + + monkeypatch.setattr(manager, "_load_adapter", _fake_load) + monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) + monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None) + + req1 = _dummy_lora_request(1) + req2 = _dummy_lora_request(2) + req3 = _dummy_lora_request(3) + + manager.set_active_adapter(req1, lora_scale=1.0) + manager.set_active_adapter(req2, lora_scale=1.0) + + # Touch adapter 1 so adapter 2 becomes LRU. + manager.set_active_adapter(req1, lora_scale=1.0) + + manager.set_active_adapter(req3, lora_scale=1.0) + + assert set(manager.list_adapters()) == {1, 3} + + +def test_lora_manager_does_not_evict_pinned_adapter(monkeypatch): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=2, + ) + + def _fake_load(_req: LoRARequest): + lora_model = type("LM", (), {"id": _req.lora_int_id})() + peft_helper = type("PH", (), {})() + return lora_model, peft_helper + + monkeypatch.setattr(manager, "_load_adapter", _fake_load) + monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) + monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None) + + manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0) + assert manager.pin_adapter(1) + + manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0) + manager.set_active_adapter(_dummy_lora_request(3), lora_scale=1.0) + + assert set(manager.list_adapters()) == {1, 3} + + +def test_lora_manager_warns_when_all_adapters_pinned(monkeypatch): + manager = DiffusionLoRAManager( + pipeline=torch.nn.Module(), + device=torch.device("cpu"), + dtype=torch.bfloat16, + max_cached_adapters=2, + ) + + def _fake_load(_req: LoRARequest): + lora_model = type("LM", (), {"id": _req.lora_int_id})() + peft_helper = type("PH", (), {})() + return lora_model, peft_helper + + monkeypatch.setattr(manager, "_load_adapter", _fake_load) + monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None) + monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None) + + manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0) + manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0) + + assert manager.pin_adapter(1) + assert manager.pin_adapter(2) + + manager.max_cached_adapters = 1 + manager._evict_if_needed() + + assert set(manager.list_adapters()) == {1, 2} diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py new file mode 100644 index 00000000000..1761f253c6a --- /dev/null +++ b/tests/e2e/offline_inference/test_diffusion_lora.py @@ -0,0 +1,138 @@ +import json +import os +import sys +from pathlib import Path + +import pytest +import torch +from safetensors.torch import save_file + +from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.utils.platform_utils import is_npu + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + + +# This test is specific to Z-Image LoRA behavior. Keep it focused on a single +# model to reduce runtime and avoid extra downloads. +models = ["Tongyi-MAI/Z-Image-Turbo"] + +# NPU still can't run Tongyi-MAI/Z-Image-Turbo properly. +if is_npu(): + pytest.skip("Tongyi-MAI/Z-Image-Turbo is not supported on NPU yet.", allow_module_level=True) + + +@pytest.mark.parametrize("model_name", models) +def test_diffusion_model(model_name: str, tmp_path: Path): + def _extract_images(outputs: list[OmniRequestOutput]): + if not outputs: + raise ValueError("Empty outputs from Omni.generate()") + first_output = outputs[0] + assert first_output.final_output_type == "image" + if not hasattr(first_output, "request_output") or not first_output.request_output: + raise ValueError("No request_output found in OmniRequestOutput") + + req_out = first_output.request_output[0] + if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): + raise ValueError("Invalid request_output structure or missing 'images' key") + return req_out.images + + def _write_zimage_lora(adapter_dir: Path) -> str: + adapter_dir.mkdir(parents=True, exist_ok=True) + + # Z-Image transformer uses dim=3840 by default (see ZImageTransformer2DModel). + dim = 3840 + module_name = "transformer.layers.0.attention.to_qkv" + rank = 1 + lora_a = torch.zeros((rank, dim), dtype=torch.float32) + lora_a[0, 0] = 1.0 + + # QKVParallelLinear packs (Q, K, V). With tp=1 and n_kv_heads==n_heads in Z-Image, + # each slice is `dim`, so total out dim is `3 * dim`. + lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) + # Apply a visible delta to the Q slice only to keep the perturbation bounded. + lora_b[:dim, 0] = 0.1 + + save_file( + { + f"base_model.model.{module_name}.lora_A.weight": lora_a, + f"base_model.model.{module_name}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps( + { + "r": rank, + "lora_alpha": rank, + "target_modules": [module_name], + } + ), + encoding="utf-8", + ) + return str(adapter_dir) + + m = Omni(model=model_name) + try: + # high resolution may cause OOM on L4 + height = 256 + width = 256 + prompt = "a photo of a cat sitting on a laptop keyboard" + + outputs = m.generate( + prompt, + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + ) + images = _extract_images(outputs) + + assert len(images) == 1 + # check image size + assert images[0].width == width + assert images[0].height == height + + # Real LoRA E2E: generate again with a real on-disk PEFT adapter and + # verify that output changes. + if model_name == "Tongyi-MAI/Z-Image-Turbo": + from vllm_omni.lora.request import LoRARequest + from vllm_omni.lora.utils import stable_lora_int_id + + lora_dir = _write_zimage_lora(tmp_path / "zimage_lora") + lora_request = LoRARequest( + lora_name="test", + lora_int_id=stable_lora_int_id(lora_dir), + lora_path=lora_dir, + ) + outputs_lora = m.generate( + prompt, + height=height, + width=width, + num_inference_steps=2, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), + num_outputs_per_prompt=1, + lora_request=lora_request, + lora_scale=2.0, + ) + images_lora = _extract_images(outputs_lora) + assert len(images_lora) == 1 + assert images_lora[0].width == width + assert images_lora[0].height == height + + import numpy as np + + diff = np.abs(np.array(images[0], dtype=np.int16) - np.array(images_lora[0], dtype=np.int16)).mean() + assert diff > 0.0 + finally: + m.close() diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py new file mode 100644 index 00000000000..38d3c9b897d --- /dev/null +++ b/tests/e2e/online_serving/test_images_generations_lora.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +E2E online serving test for /v1/images/generations with per-request LoRA. + +This validates: +- The API server accepts a per-request `lora` object in the Images API payload. +- LoRA can be switched per request (adapter A -> adapter B -> no LoRA). +- Output correctness is asserted using a small image slice with tolerance. +""" + +import base64 +import json +import os +import signal +import subprocess +import sys +import time +from io import BytesIO +from pathlib import Path + +import numpy as np +import pytest +import requests +import torch +from PIL import Image +from safetensors.torch import save_file +from vllm.utils.network_utils import get_open_port + +from vllm_omni.utils.platform_utils import is_npu + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODEL = "Tongyi-MAI/Z-Image-Turbo" + + +PROMPT = "a photo of a cat sitting on a laptop keyboard" +SIZE = "256x256" +SEED = 42 + + +class OmniServer: + """Omniserver for vLLM-Omni tests.""" + + def __init__( + self, + model: str, + serve_args: list[str], + *, + env_dict: dict[str, str] | None = None, + ) -> None: + self.model = model + self.serve_args = serve_args + self.env_dict = env_dict + self.proc: subprocess.Popen | None = None + self.host = "127.0.0.1" + self.port = get_open_port() + + def _start_server(self) -> None: + env = os.environ.copy() + env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + if self.env_dict is not None: + env.update(self.env_dict) + + cmd = [ + sys.executable, + "-m", + "vllm_omni.entrypoints.cli.main", + "serve", + self.model, + "--omni", + "--host", + self.host, + "--port", + str(self.port), + ] + self.serve_args + + print(f"Launching OmniServer with: {' '.join(cmd)}") + self.proc = subprocess.Popen( + cmd, + env=env, + cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # vllm-omni root + start_new_session=True, + ) + + # Wait for server to be ready. + max_wait = 600 + url = f"http://{self.host}:{self.port}/v1/models" + start_time = time.time() + while time.time() - start_time < max_wait: + try: + resp = requests.get(url, headers={"Authorization": "Bearer EMPTY"}, timeout=10) + if resp.status_code == 200: + print(f"Server ready on {self.host}:{self.port}") + return + except Exception: + pass + time.sleep(2) + + raise RuntimeError(f"Server failed to become ready within {max_wait} seconds") + + def __enter__(self): + self._start_server() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.proc is None: + return + try: + os.killpg(self.proc.pid, signal.SIGTERM) + except ProcessLookupError: + pass + try: + self.proc.wait(timeout=30) + except subprocess.TimeoutExpired: + try: + os.killpg(self.proc.pid, signal.SIGKILL) + except ProcessLookupError: + pass + self.proc.wait() + + +@pytest.fixture(scope="module") +def omni_server(): + if is_npu(): + pytest.skip("Tongyi-MAI/Z-Image-Turbo is not supported on NPU yet.") + with OmniServer(MODEL, ["--num-gpus", "1"]) as server: + yield server + + +def _write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0): + adapter_dir.mkdir(parents=True, exist_ok=True) + + # Z-Image transformer uses dim=3840 by default. + dim = 3840 + module_name = "transformer.layers.0.attention.to_qkv" + rank = 1 + + lora_a = torch.zeros((rank, dim), dtype=torch.float32) + lora_a[0, 0] = 1.0 + + # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1). + lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32) + if q_scale: + lora_b[:dim, 0] = q_scale + if k_scale: + lora_b[dim : 2 * dim, 0] = k_scale + if v_scale: + lora_b[2 * dim :, 0] = v_scale + + save_file( + { + f"base_model.model.{module_name}.lora_A.weight": lora_a, + f"base_model.model.{module_name}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps( + { + "r": rank, + "lora_alpha": rank, + "target_modules": [module_name], + } + ), + encoding="utf-8", + ) + + +def _post_images(server: OmniServer, payload: dict) -> Image.Image: + url = f"http://{server.host}:{server.port}/v1/images/generations" + resp = requests.post(url, json=payload, headers={"Authorization": "Bearer EMPTY"}, timeout=900) + resp.raise_for_status() + data = resp.json() + b64 = data["data"][0]["b64_json"] + img_bytes = base64.b64decode(b64) + img = Image.open(BytesIO(img_bytes)) + img.load() + return img.convert("RGB") + + +def _image_blue_tail_slice(img: Image.Image) -> np.ndarray: + arr = np.asarray(img, dtype=np.uint8) + assert arr.ndim == 3 and arr.shape[-1] == 3 + tail = arr[-3:, -3:, -1].astype(np.float32) + assert tail.shape == (3, 3) + return tail + + +def _assert_slice_close(actual: np.ndarray, expected: np.ndarray, *, label: str) -> None: + assert actual.shape == (3, 3) + assert expected.shape == (3, 3) + diff = np.abs(actual - expected) + max_diff = float(diff.max()) + mean_diff = float(diff.mean()) + # NOTE: Different attention backends / torch.compile can introduce small + # floating-point drift that shows up as a few LSBs in uint8 pixels. Keep + # the reset check tolerant but bounded to avoid flaky CI. + assert max_diff <= 5.0 and mean_diff <= 3.0, ( + f"{label} slice mismatch (max={max_diff:.1f}, mean={mean_diff:.1f}): {actual.tolist()}" + ) + + +def _assert_slice_diff(actual: np.ndarray, baseline: np.ndarray, *, label: str) -> None: + assert actual.shape == (3, 3) + assert baseline.shape == (3, 3) + diff = np.abs(actual - baseline).mean() + assert diff > 0.1, f"{label} slice diff too small: {diff} ({actual.tolist()} vs {baseline.tolist()})" + + +def _basic_payload() -> dict: + return { + "prompt": PROMPT, + "n": 1, + "size": SIZE, + "num_inference_steps": 2, + "guidance_scale": 0.0, + "seed": SEED, + } + + +def test_images_generations_per_request_lora_switching(omni_server: OmniServer, tmp_path: Path) -> None: + # Base generation. + base_img = _post_images(omni_server, _basic_payload()) + base_slice = _image_blue_tail_slice(base_img) + + # Adapter A: apply delta to Q slice only. + lora_a_dir = tmp_path / "zimage_lora_a" + _write_zimage_lora(lora_a_dir, q_scale=0.1) + payload_a = _basic_payload() + payload_a["lora"] = {"name": "a", "path": str(lora_a_dir), "scale": 2.0} + img_a = _post_images(omni_server, payload_a) + a_slice = _image_blue_tail_slice(img_a) + _assert_slice_diff(a_slice, base_slice, label="lora_a_vs_base") + a_vs_base = float(np.abs(a_slice - base_slice).mean()) + + # Adapter B: apply delta to K slice only (should differ from adapter A). + lora_b_dir = tmp_path / "zimage_lora_b" + _write_zimage_lora(lora_b_dir, k_scale=0.1) + payload_b = _basic_payload() + payload_b["lora"] = {"name": "b", "path": str(lora_b_dir), "scale": 2.0} + img_b = _post_images(omni_server, payload_b) + b_slice = _image_blue_tail_slice(img_b) + _assert_slice_diff(b_slice, base_slice, label="lora_b_vs_base") + _assert_slice_diff(b_slice, a_slice, label="lora_b_vs_lora_a") + b_vs_base = float(np.abs(b_slice - base_slice).mean()) + b_vs_a = float(np.abs(b_slice - a_slice).mean()) + + # Ensure switching back to no-LoRA restores the base output. + base_img_2 = _post_images(omni_server, _basic_payload()) + base_slice_2 = _image_blue_tail_slice(base_img_2) + _assert_slice_close(base_slice_2, base_slice, label="base_after_reset") + base_reset = float(np.abs(base_slice_2 - base_slice).mean()) + + # Ensure LoRA effects are clearly above the baseline drift. + min_delta = base_reset + 0.5 + assert a_vs_base > min_delta, f"lora_a_vs_base drift too small: {a_vs_base} <= {min_delta}" + assert b_vs_base > min_delta, f"lora_b_vs_base drift too small: {b_vs_base} <= {min_delta}" + assert b_vs_a > min_delta, f"lora_b_vs_lora_a drift too small: {b_vs_a} <= {min_delta}" diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py index cc7e8132e50..e2db6f4273c 100644 --- a/vllm_omni/config/__init__.py +++ b/vllm_omni/config/__init__.py @@ -2,8 +2,10 @@ Configuration module for vLLM-Omni. """ +from vllm_omni.config.lora import LoRAConfig from vllm_omni.config.model import OmniModelConfig __all__ = [ "OmniModelConfig", + "LoRAConfig", ] diff --git a/vllm_omni/config/lora.py b/vllm_omni/config/lora.py new file mode 100644 index 00000000000..00aba2e16b6 --- /dev/null +++ b/vllm_omni/config/lora.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# for now, it suffices to use vLLM's implementation directly +# as this is a user-facing variable, defined here to so that user can directly import LoRAConfig from vllm_omni +from vllm.config.lora import LoRAConfig + +__all__ = ["LoRAConfig"] diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index da0f5716b96..a6e712bff6a 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -20,10 +20,7 @@ get_hf_text_config, get_pooling_config, ) -from vllm.transformers_utils.gguf_utils import ( - is_gguf, - maybe_patch_hf_config_from_gguf, -) +from vllm.transformers_utils.gguf_utils import is_gguf, maybe_patch_hf_config_from_gguf from vllm.transformers_utils.utils import maybe_model_redirect from vllm.v1.attention.backends.registry import AttentionBackendEnum diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index f85903b9bd4..1a98462aaea 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -276,12 +276,9 @@ class OmniDiffusionConfig: # pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False) # LoRA parameters - # (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated. lora_path: str | None = None - lora_nickname: str = "default" # for swapping adapters in the pipeline - # can restrict layers to adapt, e.g. ["q_proj"] - # Will adapt only q, k, v, o by default. - lora_target_modules: list[str] | None = None + lora_scale: float = 1.0 + max_cpu_loras: int | None = None output_type: str = "pil" @@ -446,11 +443,24 @@ def __post_init__(self): # If it's neither dict nor DiffusionCacheConfig, convert to empty config self.cache_config = DiffusionCacheConfig() + if self.max_cpu_loras is None: + self.max_cpu_loras = 1 + elif self.max_cpu_loras < 1: + raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") + def update_multimodal_support(self) -> None: self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} @classmethod def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig": + # Backwards-compatibility: older callers may use a diffusion-specific + # "static_lora_scale" kwarg. Normalize it to the canonical "lora_scale" + # before constructing the dataclass to avoid TypeError on unknown fields. + if "static_lora_scale" in kwargs: + if "lora_scale" not in kwargs: + kwargs["lora_scale"] = kwargs["static_lora_scale"] + kwargs.pop("static_lora_scale", None) + # Check environment variable as fallback for cache_backend # Support both old DIFFUSION_CACHE_ADAPTER and new DIFFUSION_CACHE_BACKEND for backwards compatibility if "cache_backend" not in kwargs: diff --git a/vllm_omni/diffusion/lora/__init__.py b/vllm_omni/diffusion/lora/__init__.py new file mode 100644 index 00000000000..353a2c6bee6 --- /dev/null +++ b/vllm_omni/diffusion/lora/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager + +__all__ = ["DiffusionLoRAManager"] diff --git a/vllm_omni/diffusion/lora/layers/__init__.py b/vllm_omni/diffusion/lora/layers/__init__.py new file mode 100644 index 00000000000..ab501f105f3 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .base_linear import DiffusionBaseLinearLayerWithLoRA +from .column_parallel_linear import ( + DiffusionColumnParallelLinearWithLoRA, + DiffusionMergedColumnParallelLinearWithLoRA, + DiffusionMergedQKVParallelLinearWithLoRA, + DiffusionQKVParallelLinearWithLoRA, +) +from .replicated_linear import DiffusionReplicatedLinearWithLoRA +from .row_parallel_linear import DiffusionRowParallelLinearWithLoRA + +__all__ = [ + "DiffusionBaseLinearLayerWithLoRA", + "DiffusionReplicatedLinearWithLoRA", + "DiffusionColumnParallelLinearWithLoRA", + "DiffusionMergedColumnParallelLinearWithLoRA", + "DiffusionRowParallelLinearWithLoRA", + "DiffusionQKVParallelLinearWithLoRA", + "DiffusionMergedQKVParallelLinearWithLoRA", +] diff --git a/vllm_omni/diffusion/lora/layers/base_linear.py b/vllm_omni/diffusion/lora/layers/base_linear.py new file mode 100644 index 00000000000..fe32868d083 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/base_linear.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch +from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA + + +class DiffusionBaseLinearLayerWithLoRA(BaseLinearLayerWithLoRA): + """ + Diffusion-specific base that overrides apply() to use direct torch matmul + instead of punica_wrapper. + + punica_wrapper is used to hold multiple LoRA slots and slices efficiently. + + This matches the semantics of PunicaWrapperGPU.add_lora_linear(): + - Shrink: buffer = (x @ lora_a.T) + - Expand: y += buffer @ lora_b.T + + All other functionality (weight management, TP slicing, forward logic) + is inherited from vLLM's BaseLinearLayerWithLoRA. + """ + + def create_lora_weights( + self, + max_loras: int, + lora_config, + model_config=None, + ) -> None: + super().create_lora_weights(max_loras, lora_config, model_config) + # Keep a direct reference for attribute forwarding: `base_layer` is a + # registered submodule (stored under `_modules`), so direct access via + # `object.__getattribute__` will not find it. We stash a ref in + # `__dict__` for robust lookups in `__getattr__`. + modules = object.__getattribute__(self, "_modules") + base_layer = modules.get("base_layer") or object.__getattribute__(self, "__dict__").get("base_layer") + object.__setattr__(self, "_diffusion_base_layer_ref", base_layer) + n_slices = getattr(self, "n_slices", 1) + self._diffusion_lora_active_slices = (False,) * int(n_slices) + + def reset_lora(self, index: int): + super().reset_lora(index) + n_slices = getattr(self, "n_slices", 1) + self._diffusion_lora_active_slices = (False,) * int(n_slices) + + def set_lora( + self, + index: int, + lora_a: torch.Tensor | list[torch.Tensor | None], + lora_b: torch.Tensor | list[torch.Tensor | None], + ): + super().set_lora(index, lora_a, lora_b) # type: ignore[arg-type] + + n_slices = getattr(self, "n_slices", 1) + if isinstance(lora_a, list) or isinstance(lora_b, list): + assert isinstance(lora_a, list) + assert isinstance(lora_b, list) + active_slices = [] + for a_i, b_i in zip(lora_a[:n_slices], lora_b[:n_slices]): + active_slices.append(a_i is not None and b_i is not None) + if len(active_slices) < n_slices: + active_slices.extend([False] * (n_slices - len(active_slices))) + self._diffusion_lora_active_slices = tuple(active_slices) + else: + # Single-slice layer. + self._diffusion_lora_active_slices = (True,) + + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + """ + override: Use simple matmul instead of punica_wrapper.add_lora_linear(). + + This matches the exact computation in PunicaWrapperGPU.add_lora_linear() + for the single-LoRA case. For packed projections (e.g. fused QKV), we + apply LoRA per-slice using `output_slices`. + """ + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + if not hasattr(self, "lora_a_stacked") or not hasattr(self, "lora_b_stacked"): + return output + if not self.lora_a_stacked or not self.lora_b_stacked: + return output + # Fast path: if no LoRA is active for this layer, skip matmuls. + active_slices = getattr(self, "_diffusion_lora_active_slices", None) + if active_slices is not None and not any(active_slices): + return output + + # In fully-sharded LoRA mode, vLLM uses an all-gather between shrink and + # expand for ColumnParallelLinear variants. This diffusion path doesn't + # implement that communication yet. + if getattr(self, "lora_config", None) is not None: + if self.lora_config.fully_sharded_loras and self.tp_size > 1: + raise NotImplementedError( + "Diffusion LoRA apply() does not support fully_sharded_loras with tensor parallelism yet." + ) + + original_shape = output.shape + x_flat = x.reshape(-1, x.shape[-1]) + y_flat = output.reshape(-1, output.shape[-1]) + + output_slices = getattr(self, "output_slices", None) + if output_slices is None: + # Fallback: infer slice sizes from the allocated tensors. + output_slices = tuple(lora_b.shape[2] for lora_b in self.lora_b_stacked) + + if len(output_slices) != len(self.lora_a_stacked) or len(output_slices) != len(self.lora_b_stacked): + raise RuntimeError( + "LoRA slice metadata mismatch: " + f"output_slices={len(output_slices)}, " + f"lora_a_stacked={len(self.lora_a_stacked)}, " + f"lora_b_stacked={len(self.lora_b_stacked)}" + ) + + offset = 0 + for slice_idx, slice_size in enumerate(output_slices): + if active_slices is not None and slice_idx < len(active_slices) and not active_slices[slice_idx]: + offset += slice_size + continue + + A = self.lora_a_stacked[slice_idx][0, 0, :, :] # (rank, in_dim) + B = self.lora_b_stacked[slice_idx][0, 0, :, :] # (out_dim, rank) + + if A.numel() == 0 or B.numel() == 0: + offset += slice_size + continue + + # LoRA shrink & expand as in add_lora_linear(): + # buffer = (x @ A.T) + # y += buffer @ B.T + delta = (x_flat @ A.t()) @ B.t() + y_flat[:, offset : offset + slice_size] = y_flat[:, offset : offset + slice_size] + delta + offset += slice_size + + return y_flat.view(original_shape) + + def __getattr__(self, name: str): + # The diffusion model implementations may access attributes directly + # from linear layers (e.g. QKVParallelLinear.num_heads). vLLM's LoRA + # wrappers don't forward these attributes by default, so we delegate + # missing attribute lookups to the underlying base_layer. + try: + return super().__getattr__(name) + except AttributeError as exc: + base_layer = object.__getattribute__(self, "__dict__").get("_diffusion_base_layer_ref") + if base_layer is None: + base_layer = object.__getattribute__(self, "_modules").get("base_layer") + if base_layer is None: + raise exc + try: + return getattr(base_layer, name) + except AttributeError: + raise exc diff --git a/vllm_omni/diffusion/lora/layers/column_parallel_linear.py b/vllm_omni/diffusion/lora/layers/column_parallel_linear.py new file mode 100644 index 00000000000..27ac94e61ed --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/column_parallel_linear.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.lora.layers.column_parallel_linear import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, +) + +from .base_linear import DiffusionBaseLinearLayerWithLoRA + + +class DiffusionColumnParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + ColumnParallelLinearWithLoRA, +): + """ + Diffusion ColumnParallelLinear with LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass + + +class DiffusionMergedColumnParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + MergedColumnParallelLinearWithLoRA, +): + """ + Diffusion MergedColumnParallelLinear (gate_up_proj) with LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass + + +class DiffusionQKVParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + QKVParallelLinearWithLoRA, +): + """ + Diffusion QKVParallelLinear with single LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass + + +class DiffusionMergedQKVParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + MergedQKVParallelLinearWithLoRA, +): + """ + Diffusion MergedQKVParallelLinear (to_qkv) with 3 LoRAs. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass diff --git a/vllm_omni/diffusion/lora/layers/replicated_linear.py b/vllm_omni/diffusion/lora/layers/replicated_linear.py new file mode 100644 index 00000000000..e6574a04f69 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/replicated_linear.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA + +from .base_linear import DiffusionBaseLinearLayerWithLoRA + + +class DiffusionReplicatedLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + ReplicatedLinearWithLoRA, +): + """ + Diffusion ReplicatedLinear with LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass diff --git a/vllm_omni/diffusion/lora/layers/row_parallel_linear.py b/vllm_omni/diffusion/lora/layers/row_parallel_linear.py new file mode 100644 index 00000000000..ac211909213 --- /dev/null +++ b/vllm_omni/diffusion/lora/layers/row_parallel_linear.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.lora.layers.row_parallel_linear import RowParallelLinearWithLoRA + +from .base_linear import DiffusionBaseLinearLayerWithLoRA + + +class DiffusionRowParallelLinearWithLoRA( + DiffusionBaseLinearLayerWithLoRA, + RowParallelLinearWithLoRA, +): + """ + Diffusion RowParallelLinear with LoRA. + Prioritize apply() in DiffusionBaseLinearLayerWithLoRA + """ + + pass diff --git a/vllm_omni/diffusion/lora/manager.py b/vllm_omni/diffusion/lora/manager.py new file mode 100644 index 00000000000..7fad1b9e758 --- /dev/null +++ b/vllm_omni/diffusion/lora/manager.py @@ -0,0 +1,631 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from collections import OrderedDict + +import torch +import torch.nn as nn +from vllm.logger import init_logger +from vllm.lora.layers import BaseLayerWithLoRA +from vllm.lora.lora_model import LoRAModel +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.request import LoRARequest +from vllm.lora.utils import ( + get_adapter_absolute_path, + get_supported_lora_modules, + replace_submodule, +) +from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear + +from vllm_omni.config.lora import LoRAConfig +from vllm_omni.diffusion.lora.utils import ( + _expand_expected_modules_for_packed_layers, + _match_target_modules, + from_layer_diffusion, +) +from vllm_omni.lora.utils import stable_lora_int_id + +logger = init_logger(__name__) + + +class DiffusionLoRAManager: + """Manager for LoRA adapters in diffusion models. + + Reuses vLLM's LoRA infrastructure, adapted for diffusion pipelines. + Uses LRU cache management similar to LRUCacheLoRAModelManager. + """ + + def __init__( + self, + pipeline: nn.Module, + device: torch.device, + dtype: torch.dtype, + max_cached_adapters: int = 1, + lora_path: str | None = None, + lora_scale: float = 1.0, + ): + """ + Initialize the DiffusionLoRAManager. + + Args: + max_cached_adapters: Maximum number of LoRA adapters to keep in the + CPU-side cache (LRU). This mirrors vLLM's `max_cpu_loras` and is + exposed to users via `OmniDiffusionConfig.max_cpu_loras`. + """ + self.pipeline = pipeline + self.device = device + self.dtype = dtype + + # Cache supported/expected module suffixes once, before any layer + # replacement happens. After LoRA layers are injected, the original + # LinearBase layers become submodules named "*.base_layer", and calling + # vLLM's get_supported_lora_modules() again would incorrectly yield + # "base_layer" instead of the real target module suffixes. + self._supported_lora_modules = self._compute_supported_lora_modules() + self._packed_modules_mapping = self._compute_packed_modules_mapping() + self._expected_lora_modules = _expand_expected_modules_for_packed_layers( + self._supported_lora_modules, + self._packed_modules_mapping, + ) + + # LRU-style cache management + self.max_cached_adapters = max_cached_adapters # max_cpu_loras + self._registered_adapters: dict[int, LoRAModel] = {} # adapter_id -> LoRAModel + self._active_adapter_id: int | None = None + self._adapter_scales: dict[int, float] = {} # adapter_id -> external scale + + # LRU cache tracking (adapter_id -> last_used_time) + self._adapter_access_order: OrderedDict[int, float] = OrderedDict() + # Pinned adapters are not evicted + self._pinned_adapters: set[int] = set() + + # track replaced modules + # key: full module name (component.module.path); value: LoRA layer + self._lora_modules: dict[str, BaseLayerWithLoRA] = {} + # Track the maximum LoRA rank we've allocated buffers for. + self._max_lora_rank: int = 0 + + logger.info( + "Initializing DiffusionLoRAManager: device=%s, dtype=%s, max_cached_adapters=%d, static_lora_path=%s", + device, + dtype, + max_cached_adapters, + lora_path, + ) + + if lora_path is not None: + logger.info("Loading LoRA during initialization from %s with scale %.2f", lora_path, lora_scale) + init_request = LoRARequest( + lora_name="static", + lora_int_id=stable_lora_int_id(lora_path), + lora_path=lora_path, + ) + self.set_active_adapter(init_request, lora_scale) + + def _compute_supported_lora_modules(self) -> set[str]: + """Compute supported LoRA module suffixes for this pipeline. + + vLLM's get_supported_lora_modules() returns suffixes for LinearBase + modules. After this manager replaces layers with BaseLayerWithLoRA + wrappers, those LinearBase modules become nested under ".base_layer", + which would cause get_supported_lora_modules() to return "base_layer". + To make adapter loading stable across multiple adapters, we also accept + suffixes from existing BaseLayerWithLoRA wrappers and drop "base_layer" + when appropriate. + """ + supported = set(get_supported_lora_modules(self.pipeline)) + + has_lora_wrappers = False + for name, module in self.pipeline.named_modules(): + if isinstance(module, BaseLayerWithLoRA): + has_lora_wrappers = True + supported.add(name.split(".")[-1]) + + if has_lora_wrappers: + supported.discard("base_layer") + + return supported + + def _compute_packed_modules_mapping(self) -> dict[str, list[str]]: + """Collect packed->sublayer mappings from the diffusion model. + + vLLM models declare `packed_modules_mapping` on the model class. For + diffusion pipelines, we attach the same mapping on the transformer + module(s) that implement packed (fused) projections, so LoRA loading can + accept checkpoints trained against the logical sub-projections. + """ + mapping: dict[str, list[str]] = {} + for module in self.pipeline.modules(): + packed = getattr(module, "packed_modules_mapping", None) + if not isinstance(packed, dict): + continue + for packed_name, sub_names in packed.items(): + if not isinstance(packed_name, str) or not packed_name: + continue + if not isinstance(sub_names, (list, tuple)) or not all(isinstance(s, str) for s in sub_names): + continue + sub_names_list = list(sub_names) + if not sub_names_list: + continue + + existing = mapping.get(packed_name) + if existing is None: + mapping[packed_name] = sub_names_list + elif existing != sub_names_list: + logger.warning( + "Conflicting packed_modules_mapping for %s: %s vs %s; using %s", + packed_name, + existing, + sub_names_list, + existing, + ) + + return mapping + + def _get_packed_sublayer_suffixes(self, packed_module_suffix: str, n_slices: int) -> list[str] | None: + sub_suffixes = self._packed_modules_mapping.get(packed_module_suffix) + if not sub_suffixes: + return None + if len(sub_suffixes) != n_slices: + logger.warning( + "packed_modules_mapping[%s] has %d slices but layer expects %d; skipping sublayer lookup", + packed_module_suffix, + len(sub_suffixes), + n_slices, + ) + return None + return sub_suffixes + + def set_active_adapter(self, lora_request: LoRARequest | None, lora_scale: float = 1.0) -> None: + """Set the active LoRA adapter for the pipeline. + + Args: + lora_request: The LoRA request, or None to deactivate all adapters. + lora_scale: The external scale for the LoRA adapter. + """ + if lora_request is None: + logger.debug("No lora_request provided, deactivating all LoRA adapters") + self._deactivate_all_adapters() + return + + adapter_id = lora_request.lora_int_id + logger.debug( + "Setting active adapter: id=%d, name=%s, path=%s, scale=%.2f, cache_size=%d/%d", + adapter_id, + lora_request.lora_name, + lora_request.lora_path, + lora_scale, + len(self._registered_adapters), + self.max_cached_adapters, + ) + if adapter_id not in self._registered_adapters: + logger.info("Loading new adapter: id=%d, name=%s", adapter_id, lora_request.lora_name) + self.add_adapter(lora_request, lora_scale) + else: + logger.debug("Adapter %d already loaded, activating", adapter_id) + + # update access order + self._adapter_scales[adapter_id] = lora_scale + self._adapter_access_order[adapter_id] = time.time() + self._adapter_access_order.move_to_end(adapter_id) + + self._activate_adapter(adapter_id) + + def _load_adapter( + self, + lora_request: LoRARequest, + ) -> tuple[LoRAModel, PEFTHelper]: + if not self._expected_lora_modules: + raise ValueError("No supported LoRA modules found in the diffusion pipeline.") + + logger.debug("Supported LoRA modules: %s", self._expected_lora_modules) + + lora_path = get_adapter_absolute_path(lora_request.lora_path) + logger.debug("Resolved LoRA path: %s", lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, + max_position_embeddings=None, # no need in diffusion + tensorizer_config_dict=lora_request.tensorizer_config_dict, + ) + + logger.info( + "Loaded PEFT config: r=%d, lora_alpha=%d, target_modules=%s", + peft_helper.r, + peft_helper.lora_alpha, + peft_helper.target_modules, + ) + + lora_model = LoRAModel.from_local_checkpoint( + lora_path, + expected_lora_modules=self._expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", # consistent w/ vllm's behavior + dtype=self.dtype, + model_vocab_size=None, + tensorizer_config_dict=lora_request.tensorizer_config_dict, + weights_mapper=None, + ) + + logger.info( + "Loaded LoRA model: id=%d, num_modules=%d, modules=%s", + lora_model.id, + len(lora_model.loras), + list(lora_model.loras.keys()), + ) + + for lora in lora_model.loras.values(): + lora.optimize() # ref: _create_merged_loras_inplace, internal scaling + + return lora_model, peft_helper + + def _get_packed_modules_list(self, module: nn.Module) -> list[str]: + """Return a packed_modules_list suitable for vLLM LoRA can_replace_layer(). + + Diffusion transformers frequently use packed projection layers like + QKVParallelLinear (fused QKV). vLLM's LoRA replacement logic relies on + `packed_modules_list` length to decide between single-slice vs packed + LoRA layer implementations. + """ + if isinstance(module, QKVParallelLinear): + # Treat diffusion QKV as a 3-slice packed projection by default. + return ["q", "k", "v"] + if isinstance(module, MergedColumnParallelLinear): + # 2-slice packed projection (e.g. fused MLP projections). + return ["0", "1"] + return [] + + def _replace_layers_with_lora(self, peft_helper: PEFTHelper) -> None: + self._ensure_max_lora_rank(peft_helper.r) + + target_modules = getattr(peft_helper, "target_modules", None) + target_modules_list: list[str] | None = None + target_modules_pattern: str | None = None + if isinstance(target_modules, str) and target_modules: + target_modules_pattern = target_modules + elif isinstance(target_modules, list) and target_modules: + target_modules_list = target_modules + + def _matches_target(module_name: str) -> bool: + if target_modules_pattern is not None: + import regex as re + + return re.search(target_modules_pattern, module_name) is not None + if target_modules_list is None: + return True + return _match_target_modules(module_name, target_modules_list) + + # dummy lora config + lora_config = LoRAConfig( + max_lora_rank=self._max_lora_rank, + max_loras=1, + max_cpu_loras=self.max_cached_adapters, + lora_dtype=self.dtype, + fully_sharded_loras=False, + ) + + for component_name in ("transformer", "transformer_2", "dit"): + if not hasattr(self.pipeline, component_name): + continue + component = getattr(self.pipeline, component_name) + if not isinstance(component, nn.Module): + continue + + for module_name, module in component.named_modules(remove_duplicate=False): + # Don't recurse into already-replaced LoRA wrappers. Their + # original LinearBase lives under "base_layer", and replacing + # that again would nest LoRA wrappers and break execution. + if isinstance(module, BaseLayerWithLoRA) or "base_layer" in module_name.split("."): + continue + + full_module_name = f"{component_name}.{module_name}" + if full_module_name in self._lora_modules: + logger.debug("Layer %s already replaced, skipping", full_module_name) + continue + + packed_modules_list = self._get_packed_modules_list(module) + if target_modules_pattern is not None or target_modules_list is not None: + should_replace = _matches_target(full_module_name) + if not should_replace and len(packed_modules_list) > 1: + prefix, _, packed_suffix = full_module_name.rpartition(".") + sub_suffixes = self._get_packed_sublayer_suffixes(packed_suffix, len(packed_modules_list)) + if sub_suffixes is not None: + for sub_suffix in sub_suffixes: + sub_full_name = f"{prefix}.{sub_suffix}" if prefix else sub_suffix + if _matches_target(sub_full_name): + should_replace = True + break + + if not should_replace: + continue + + lora_layer = from_layer_diffusion( + layer=module, + max_loras=1, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=None, + ) + + if lora_layer is not module and isinstance(lora_layer, BaseLayerWithLoRA): + replace_submodule(component, module_name, lora_layer) + self._lora_modules[full_module_name] = lora_layer + logger.debug("Replaced layer: %s -> %s", full_module_name, type(lora_layer).__name__) + + def _ensure_max_lora_rank(self, min_rank: int) -> None: + """Ensure LoRA buffers can accommodate adapters up to `min_rank`. + + We allocate per-layer LoRA buffers once when we first replace layers. + If a later adapter has a larger rank, we need to reinitialize those + buffers and re-apply the currently active adapter. + """ + if min_rank <= self._max_lora_rank: + return + + if min_rank <= 0: + raise ValueError(f"Invalid LoRA rank: {min_rank}") + + logger.info("Increasing max LoRA rank: %d -> %d", self._max_lora_rank, min_rank) + self._max_lora_rank = min_rank + + if not self._lora_modules: + return + + lora_config = LoRAConfig( + max_lora_rank=self._max_lora_rank, + max_loras=1, + max_cpu_loras=self.max_cached_adapters, + lora_dtype=self.dtype, + fully_sharded_loras=False, + ) + + # Recreate per-layer buffers with the new maximum rank. + for lora_layer in self._lora_modules.values(): + lora_layer.create_lora_weights(max_loras=1, lora_config=lora_config, model_config=None) + + # Re-apply active adapter if needed (buffers were reset). + if self._active_adapter_id is not None: + active_id = self._active_adapter_id + self._active_adapter_id = None + self._activate_adapter(active_id) + + def _get_lora_weights( + self, + lora_model: LoRAModel, + full_module_name: str, + ) -> LoRALayerWeights | PackedLoRALayerWeights | None: + """Best-effort lookup for LoRA weights by name. + + Tries: + - Full module name (e.g. transformer.blocks.0.attn.to_qkv) + - Relative name without the top-level component (e.g. blocks.0.attn.to_qkv) + - Suffix-only name (e.g. to_qkv) + """ + lora_weights = lora_model.get_lora(full_module_name) + if lora_weights is not None: + return lora_weights + + component_relative_name = full_module_name.split(".", 1)[-1] if "." in full_module_name else full_module_name + lora_weights = lora_model.get_lora(component_relative_name) + if lora_weights is not None: + return lora_weights + + module_suffix = full_module_name.split(".")[-1] + return lora_model.get_lora(module_suffix) + + def _activate_adapter(self, adapter_id: int) -> None: + if self._active_adapter_id == adapter_id: + logger.debug("Adapter %d already active, skipping", adapter_id) + return + + logger.info("Activating adapter: id=%d", adapter_id) + lora_model = self._registered_adapters[adapter_id] + + # activate weights in each LoRA layer + for full_module_name, lora_layer in self._lora_modules.items(): + lora_weights = self._get_lora_weights(lora_model, full_module_name) + + if lora_weights is None: + n_slices = getattr(lora_layer, "n_slices", 1) + if n_slices > 1: + prefix, _, packed_suffix = full_module_name.rpartition(".") + sub_suffixes = self._get_packed_sublayer_suffixes(packed_suffix, n_slices) + if sub_suffixes is None: + lora_layer.reset_lora(0) + continue + + sub_loras: list[LoRALayerWeights | None] = [] + any_found = False + for sub_suffix in sub_suffixes: + sub_full_name = f"{prefix}.{sub_suffix}" if prefix else sub_suffix + sub_lora = self._get_lora_weights(lora_model, sub_full_name) + if sub_lora is not None: + any_found = True + # Packed layers expect plain (non-packed) subloras. + if isinstance(sub_lora, PackedLoRALayerWeights): + sub_lora = None + sub_loras.append(sub_lora if isinstance(sub_lora, LoRALayerWeights) else None) + + if not any_found: + lora_layer.reset_lora(0) + continue + + scale = self._adapter_scales.get(adapter_id, 1.0) + lora_a_list: list[torch.Tensor | None] = [] + lora_b_list: list[torch.Tensor | None] = [] + for sub_lora in sub_loras: + if sub_lora is None: + lora_a_list.append(None) + lora_b_list.append(None) + continue + lora_a_list.append(sub_lora.lora_a) + lora_b_list.append(sub_lora.lora_b * scale) + + lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) + logger.debug( + "Activated packed LoRA for %s via submodules=%s (scale=%.2f)", + full_module_name, + sub_suffixes, + scale, + ) + else: + lora_layer.reset_lora(0) + continue + + scale = self._adapter_scales.get(adapter_id, 1.0) + + # Packed LoRA weights already provide per-slice tensors. + if isinstance(lora_weights, PackedLoRALayerWeights): + lora_a_list = lora_weights.lora_a + lora_b_list = [ + None if b is None else b * scale # type: ignore[operator] + for b in lora_weights.lora_b + ] + lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) + logger.debug( + "Activated packed LoRA for %s (scale=%.2f)", + full_module_name, + scale, + ) + continue + + # Fused (non-packed) weights: if the layer is multi-slice, split B. + n_slices = getattr(lora_layer, "n_slices", 1) + if n_slices > 1: + output_slices = getattr(lora_layer, "output_slices", None) + if output_slices is None: + lora_layer.reset_lora(0) + continue + + total = sum(output_slices) + if lora_weights.lora_b.shape[0] != total: + logger.warning( + "Skipping LoRA for %s due to shape mismatch: lora_b[0]=%d != sum(output_slices)=%d", + full_module_name, + lora_weights.lora_b.shape[0], + total, + ) + lora_layer.reset_lora(0) + continue + + b_splits = list(torch.split(lora_weights.lora_b, list(output_slices), dim=0)) + lora_a_list = [lora_weights.lora_a] * n_slices + lora_b_list = [b * scale for b in b_splits] + lora_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list) + logger.debug( + "Activated fused LoRA for packed layer %s (scale=%.2f)", + full_module_name, + scale, + ) + continue + + scaled_lora_b = lora_weights.lora_b * scale + lora_layer.set_lora(index=0, lora_a=lora_weights.lora_a, lora_b=scaled_lora_b) + logger.debug( + "Activated LoRA for %s: lora_a shape=%s, lora_b shape=%s, scale=%.2f", + full_module_name, + lora_weights.lora_a.shape, + lora_weights.lora_b.shape, + scale, + ) + + self._active_adapter_id = adapter_id + + def _deactivate_all_adapters(self) -> None: + logger.info("Deactivating all adapters: %d layers", len(self._lora_modules)) + for lora_layer in self._lora_modules.values(): + lora_layer.reset_lora(0) + self._active_adapter_id = None + logger.debug("All adapters deactivated") + + def _evict_if_needed(self) -> None: + while len(self._registered_adapters) > self.max_cached_adapters: + # Pick LRU among non-pinned adapters + evict_candidates = [aid for aid in self._adapter_access_order.keys() if aid not in self._pinned_adapters] + if not evict_candidates: + logger.warning( + "Cache full (%d) but all adapters are pinned; cannot evict. " + "Increase max_cached_adapters or unpin adapters.", + self.max_cached_adapters, + ) + break + + lru_adapter_id = evict_candidates[0] + logger.info( + "Evicting LRU adapter: id=%d (cache: %d/%d)", + lru_adapter_id, + len(self._registered_adapters), + self.max_cached_adapters, + ) + self.remove_adapter(lru_adapter_id) + + def add_adapter(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """ + Add a new adapter to the cache without activating it. + """ + adapter_id = lora_request.lora_int_id + + if adapter_id in self._registered_adapters: + logger.debug("Adapter %d already registered, skipping", adapter_id) + return False + + logger.info("Adding new adapter: id=%d, name=%s", adapter_id, lora_request.lora_name) + lora_model, peft_helper = self._load_adapter(lora_request) + self._registered_adapters[adapter_id] = lora_model + self._adapter_scales[adapter_id] = lora_scale + + self._replace_layers_with_lora(peft_helper) + + self._adapter_access_order[adapter_id] = time.time() + self._adapter_access_order.move_to_end(adapter_id) + + # evict if cache full + self._evict_if_needed() + + logger.debug( + "Adapter %d added, cache size: %d/%d", adapter_id, len(self._registered_adapters), self.max_cached_adapters + ) + return True + + def remove_adapter(self, adapter_id: int) -> bool: + """ + Remove an adapter from the cache. + """ + if adapter_id not in self._registered_adapters: + logger.debug("Adapter %d not found, cannot remove", adapter_id) + return False + + logger.info("Removing adapter: id=%d", adapter_id) + if self._active_adapter_id == adapter_id: + self._deactivate_all_adapters() + + del self._registered_adapters[adapter_id] + self._adapter_scales.pop(adapter_id, None) + self._adapter_access_order.pop(adapter_id, None) + self._pinned_adapters.discard(adapter_id) + logger.debug( + "Adapter %d removed, cache size: %d/%d", + adapter_id, + len(self._registered_adapters), + self.max_cached_adapters, + ) + return True + + def list_adapters(self) -> list[int]: + """Return list of registered adapter ids.""" + return list(self._registered_adapters.keys()) + + def pin_adapter(self, adapter_id: int) -> bool: + """Mark an adapter as pinned so it will not be evicted.""" + if adapter_id not in self._registered_adapters: + logger.debug("Adapter %d not found, cannot pin", adapter_id) + return False + self._pinned_adapters.add(adapter_id) + # Touch access order so it is most recently used + self._adapter_access_order[adapter_id] = time.time() + self._adapter_access_order.move_to_end(adapter_id) + logger.info("Pinned adapter id=%d (won't be evicted)", adapter_id) + return True diff --git a/vllm_omni/diffusion/lora/utils.py b/vllm_omni/diffusion/lora/utils.py new file mode 100644 index 00000000000..5f1baea34df --- /dev/null +++ b/vllm_omni/diffusion/lora/utils.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm_omni.config.lora import LoRAConfig +from vllm_omni.diffusion.lora.layers import ( + DiffusionColumnParallelLinearWithLoRA, + DiffusionMergedColumnParallelLinearWithLoRA, + DiffusionMergedQKVParallelLinearWithLoRA, + DiffusionQKVParallelLinearWithLoRA, + DiffusionReplicatedLinearWithLoRA, + DiffusionRowParallelLinearWithLoRA, +) + + +def _match_target_modules(module_name: str, target_modules: list[str]) -> bool: + """from vllm/lora/model_manager.py _match_target_modules, helper function""" + import regex as re + + return any( + re.match(rf".*\.{target_module}$", module_name) or target_module == module_name + for target_module in target_modules + ) + + +def _expand_expected_modules_for_packed_layers( + supported_modules: set[str], + packed_modules_mapping: dict[str, list[str]] | None, +) -> set[str]: + """Expand expected LoRA module suffixes for packed (fused) projections. + + Some diffusion models use packed projections like `to_qkv` or `w13`, while + LoRA checkpoints are typically saved against the logical sub-projections + (e.g. `to_q`/`to_k`/`to_v`, `w1`/`w3`). The packed layer name is present in + `supported_modules`, but the sublayer names are not. Expanding the set + ensures these sublayer keys are not dropped when loading a LoRA checkpoint. + + The packed→sublayer mapping is model-specific (see each diffusion model's + `packed_modules_mapping`) so new packed layers are added alongside the model + implementation rather than hard-coded in the LoRA framework. + """ + expanded = set(supported_modules) + if not packed_modules_mapping: + return expanded + + for packed_name, sub_names in packed_modules_mapping.items(): + if packed_name in supported_modules: + expanded.update(sub_names) + + return expanded + + +def from_layer_diffusion( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: list[str], + model_config: PretrainedConfig | None = None, +) -> nn.Module: + """ + Diffusion-specific layer replacement. similar to vLLM's `from_layer` + """ + diffusion_lora_classes = [ + DiffusionMergedQKVParallelLinearWithLoRA, + DiffusionQKVParallelLinearWithLoRA, + DiffusionMergedColumnParallelLinearWithLoRA, + DiffusionColumnParallelLinearWithLoRA, + DiffusionRowParallelLinearWithLoRA, + DiffusionReplicatedLinearWithLoRA, + ] + + for lora_cls in diffusion_lora_classes: + if lora_cls.can_replace_layer( + source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + ): + instance = lora_cls(layer) # type: ignore[arg-type] + instance.create_lora_weights(max_loras, lora_config, model_config) + return instance + + return layer diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 0398bfbc26c..892954ce9ea 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -94,6 +94,7 @@ def _prepare_weights( load_format = self.load_config.load_format use_safetensors = False index_file = DIFFUSION_MODEL_WEIGHTS_INDEX + index_file_with_subfolder = f"{subfolder}/{index_file}" if subfolder else index_file # only hf is supported currently if load_format == "auto": @@ -129,8 +130,8 @@ def _prepare_weights( for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True + # Decide by actual files rather than pattern name (patterns may include subfolders). + use_safetensors = any(f.endswith(".safetensors") for f in hf_weights_files) break if use_safetensors: @@ -142,11 +143,22 @@ def _prepare_weights( if not is_local: download_safetensors_index_file_from_hf( model_name_or_path, - index_file, + index_file_with_subfolder, self.load_config.download_dir, revision, ) - hf_weights_files = filter_duplicate_safetensors_files(hf_weights_files, hf_folder, index_file) + # Some diffusers pipelines keep component weights under a + # subfolder (e.g. "transformer/") and the corresponding index file + # uses filenames relative to that subfolder. vLLM's + # `filter_duplicate_safetensors_files` expects weight_map entries + # to be relative to the `hf_folder` we pass in, so we point it to + # the component subfolder to avoid filtering out all shards. + filter_folder = os.path.join(hf_folder, subfolder) if subfolder is not None else hf_folder + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, + filter_folder, + index_file, + ) else: hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) @@ -188,8 +200,9 @@ def get_all_weights( def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights( - model_config.model, - model_config.revision, + model_name_or_path=model_config.model, + subfolder=None, + revision=model_config.revision, fall_back_to_pt=True, allow_patterns_overrides=None, ) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index 86658a01deb..312ac9b6a2e 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -518,6 +518,10 @@ class Flux2Transformer2DModel(nn.Module): """ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } def __init__( self, diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py index 615b9194af4..c85c0bfd9e0 100644 --- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py +++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py @@ -551,6 +551,10 @@ class GlmImageTransformer2DModel(CachedTransformer): are read from `od_config.tf_model_config`. """ + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + } + def __init__( self, od_config: OmniDiffusionConfig, diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py index d2a55046df5..6bb282a80f4 100644 --- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -503,6 +503,11 @@ class LongCatImageTransformer2DModel(nn.Module): Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig. """ + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } + def __init__( self, od_config: OmniDiffusionConfig, diff --git a/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py b/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py index 06bbf570666..ae6cf6b0ccf 100644 --- a/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py +++ b/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py @@ -366,6 +366,10 @@ class OvisImageTransformer2DModel(nn.Module): """ _repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } def __init__( self, diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 14905e2a895..04890fbb93a 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -782,6 +782,10 @@ class QwenImageTransformer2DModel(CachedTransformer): # -- typically a transformer layer # used for torch compile optimizations _repeated_blocks = ["QwenImageTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } # Sequence Parallelism plan (following diffusers' _cp_plan pattern) # Similar to Z-Image's UnifiedPrepare, we use ImageRopePrepare to create diff --git a/vllm_omni/diffusion/models/sd3/sd3_transformer.py b/vllm_omni/diffusion/models/sd3/sd3_transformer.py index 22a11741a53..e60bcbe5a14 100644 --- a/vllm_omni/diffusion/models/sd3/sd3_transformer.py +++ b/vllm_omni/diffusion/models/sd3/sd3_transformer.py @@ -322,6 +322,10 @@ class SD3Transformer2DModel(nn.Module): """ _repeated_blocks = ["SD3TransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], + } def __init__( self, diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index a1e2d789bc6..f56833e488d 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -516,6 +516,9 @@ class WanTransformer3DModel(nn.Module): """ _repeated_blocks = ["WanTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + } def __init__( self, diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 247d914e950..2b7d4eb5b4f 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -544,6 +544,10 @@ class ZImageTransformer2DModel(CachedTransformer): """ _repeated_blocks = ["ZImageTransformerBlock"] + packed_modules_mapping = { + "to_qkv": ["to_q", "to_k", "to_v"], + "w13": ["w1", "w3"], + } # Sequence Parallelism for Z-Image (following diffusers' _cp_plan pattern) # Similar to how Wan uses `rope` module's split_output to shard rotary embeddings, diff --git a/vllm_omni/diffusion/request.py b/vllm_omni/diffusion/request.py index 279b3bf1dfc..89c0a79f146 100644 --- a/vllm_omni/diffusion/request.py +++ b/vllm_omni/diffusion/request.py @@ -9,6 +9,8 @@ import PIL.Image import torch +from vllm_omni.lora.request import LoRARequest + @dataclass class OmniDiffusionRequest: @@ -142,6 +144,10 @@ class OmniDiffusionRequest: save_output: bool = True return_frames: bool = False + # LoRA + lora_request: LoRARequest | None = None + lora_scale: float = 1.0 + # STA parameters STA_param: list | None = None is_cfg_negative: bool = False diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_worker.py b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py index 90d169a0e08..cadb6b64dc0 100644 --- a/vllm_omni/diffusion/worker/gpu_diffusion_worker.py +++ b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py @@ -28,9 +28,11 @@ initialize_model_parallel, ) from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager from vllm_omni.diffusion.profiler import CurrentProfiler from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.worker.gpu_diffusion_model_runner import GPUDiffusionModelRunner +from vllm_omni.lora.request import LoRARequest logger = init_logger(__name__) @@ -61,6 +63,7 @@ def __init__( self.vllm_config: VllmConfig | None = None self.model_runner: GPUDiffusionModelRunner | None = None self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + self.lora_manager: DiffusionLoRAManager | None = None self.init_device() def init_device(self) -> None: @@ -110,6 +113,15 @@ def init_device(self) -> None: self.model_runner.load_model( memory_pool_context_fn=self._maybe_get_memory_pool_context, ) + assert self.model_runner.pipeline is not None + self.lora_manager = DiffusionLoRAManager( + pipeline=self.model_runner.pipeline, + device=self.device, + dtype=self.od_config.dtype, + max_cached_adapters=self.od_config.max_cpu_loras, + lora_path=self.od_config.lora_path, + lora_scale=self.od_config.lora_scale, + ) logger.info(f"Worker {self.rank}: Initialization complete.") def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: @@ -129,6 +141,33 @@ def stop_profile(cls) -> dict | None: def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput: """Execute a forward pass by delegating to the model runner.""" assert self.model_runner is not None, "Model runner not initialized" + if self.lora_manager is not None and reqs: + req = reqs[0] + + if len(reqs) > 1: + # This worker (and the current diffusion model runner) applies + # a single LoRA to the whole batch. Reject inconsistent LoRA + # settings to avoid silently applying the wrong adapter. + def _lora_key(r: OmniDiffusionRequest): + if r.lora_request is None: + return None + lr = r.lora_request + return (lr.lora_name, lr.lora_int_id, lr.lora_path, lr.tensorizer_config_dict) + + key0 = _lora_key(req) + scale0 = req.lora_scale if key0 is not None else None + for other in reqs[1:]: + if _lora_key(other) != key0: + raise ValueError("All requests in a diffusion batch must share the same LoRARequest.") + if key0 is not None and other.lora_scale != scale0: + raise ValueError("All requests in a diffusion batch must share the same lora_scale.") + + try: + self.lora_manager.set_active_adapter(req.lora_request, req.lora_scale) + except Exception as exc: + if req.lora_request is not None: + raise + logger.warning("LoRA activation skipped: %s", exc) return self.model_runner.execute_model(reqs) def load_weights(self, weights) -> set[str]: @@ -136,6 +175,18 @@ def load_weights(self, weights) -> set[str]: assert self.model_runner is not None, "Model runner not initialized" return self.model_runner.load_weights(weights) + def remove_lora(self, adapter_id: int) -> bool: + return self.lora_manager.remove_adapter(adapter_id) + + def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + return self.lora_manager.add_adapter(lora_request, lora_scale) + + def list_loras(self) -> list[int]: + return self.lora_manager.list_adapters() + + def pin_lora(self, adapter_id: int) -> bool: + return self.lora_manager.pin_adapter(adapter_id) + def sleep(self, level: int = 1) -> bool: """ Put the worker to sleep, offloading model weights. diff --git a/vllm_omni/engine/input_processor.py b/vllm_omni/engine/input_processor.py index eb81f38dc66..91e12f6a6a6 100644 --- a/vllm_omni/engine/input_processor.py +++ b/vllm_omni/engine/input_processor.py @@ -7,7 +7,6 @@ from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs.parse import split_enc_dec_inputs from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.utils import argsort_mm_positions @@ -25,6 +24,7 @@ PromptEmbedsPayload, ) from vllm_omni.inputs.preprocess import OmniInputPreprocessor +from vllm_omni.lora.request import LoRARequest logger = init_logger(__name__) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 3c275147fa0..46e0a243d3b 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -11,13 +11,11 @@ from vllm.config import VllmConfig from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.plugins.io_processors import get_io_processor from vllm.sampling_params import SamplingParams from vllm.tokenizers import TokenizerLike from vllm.v1.engine.exceptions import EngineDeadError -# Internal imports (our code) from vllm_omni.config import OmniModelConfig from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector @@ -34,6 +32,9 @@ from vllm_omni.entrypoints.utils import ( get_final_stage_id_for_e2e, ) + +# Internal imports (our code) +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 1d47205f172..d29e71e3422 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -22,6 +22,7 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -151,6 +152,7 @@ async def generate( negative_prompt: str | None = None, num_outputs_per_prompt: int = 1, seed: int | None = None, + lora_request=None, **kwargs: Any, ) -> OmniRequestOutput: """Generate images asynchronously from a text prompt. @@ -186,6 +188,7 @@ async def generate( "negative_prompt": negative_prompt, "num_outputs_per_prompt": num_outputs_per_prompt, "seed": seed, + "lora_request": lora_request, **kwargs, } if guidance_scale is not None: @@ -302,3 +305,65 @@ def is_running(self) -> bool: def is_stopped(self) -> bool: """Check if the engine is stopped.""" return self._closed + + async def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "remove_lora", + None, + (adapter_id,), + {}, + None, + ) + return all(results) if isinstance(results, list) else results + + async def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """Add a LoRA adapter""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "add_lora", + None, + (), + {"lora_request": lora_request, "lora_scale": lora_scale}, + None, + ) + return all(results) if isinstance(results, list) else results + + async def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "list_loras", + None, + (), + {}, + None, + ) + # collective_rpc returns list from workers; flatten unique ids + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + merged.update(part or []) + return sorted(merged) + + async def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted.""" + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "pin_lora", + None, + (), + {"adapter_id": lora_id}, + None, + ) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 3d918aea79c..d063938a5f5 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -221,6 +221,26 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: self.config_path = stage_configs_path self.stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=base_engine_args) + # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. + for cfg in self.stage_configs: + try: + if getattr(cfg, "stage_type", None) != "diffusion": + continue + if not hasattr(cfg, "engine_args") or cfg.engine_args is None: + cfg.engine_args = OmegaConf.create({}) + if kwargs.get("lora_path") is not None: + if not hasattr(cfg.engine_args, "lora_path") or cfg.engine_args.lora_path is None: + cfg.engine_args.lora_path = kwargs["lora_path"] + lora_scale = kwargs.get("lora_scale") + if lora_scale is None: + # Backwards compatibility for older callers. + lora_scale = kwargs.get("static_lora_scale") + if lora_scale is not None: + if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: + cfg.engine_args.lora_scale = lora_scale + except Exception as e: + logger.warning("Failed to inject LoRA config for stage: %s", e) + # Initialize connectors self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors( self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 2369a10001e..3c44da8b32b 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -33,7 +33,14 @@ ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, + ModelCard, + ModelList, + ModelPermission, ) + +# yapf conflicts with isort for this block +# yapf: disable +# yapf: enable from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses @@ -73,6 +80,8 @@ ) from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id logger = init_logger(__name__) @@ -100,6 +109,31 @@ def _remove_route_from_router(router_obj, path: str, methods: set[str] | None = _remove_route_from_router(router, "/v1/chat/completions", {"POST"}) +class _DiffusionServingModels: + """Minimal OpenAIServingModels implementation for diffusion-only servers. + + vLLM's /v1/models route expects `app.state.openai_serving_models` to expose + `show_available_models()`. In pure diffusion mode we don't initialize the + full OpenAIServingModels (it depends on LLM-specific processors), so we + provide a lightweight fallback. + """ + + def __init__(self, base_model_paths: list[BaseModelPath]) -> None: + self._base_model_paths = base_model_paths + + async def show_available_models(self) -> ModelList: + return ModelList( + data=[ + ModelCard( + id=base_model.name, + root=base_model.model_path, + permission=[ModelPermission()], + ) + for base_model in self._base_model_paths + ] + ) + + # Server entry points @@ -330,6 +364,7 @@ async def omni_init_app_state( model_name = served_model_names[0] if served_model_names else args.model state.vllm_config = None state.diffusion_engine = engine_client + state.openai_serving_models = _DiffusionServingModels(base_model_paths) # Use for_diffusion method to create chat handler state.openai_serving_chat = OmniOpenAIServingChat.for_diffusion( @@ -823,6 +858,40 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) "num_outputs_per_prompt": request.n, } + # Parse per-request LoRA (compatible with chat's extra_body.lora shape). + if request.lora is not None: + if not isinstance(request.lora, dict): + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Invalid lora field: expected an object.", + ) + lora_body = request.lora + lora_name = lora_body.get("name") or lora_body.get("lora_name") or lora_body.get("adapter") + lora_path = ( + lora_body.get("local_path") + or lora_body.get("path") + or lora_body.get("lora_path") + or lora_body.get("lora_local_path") + ) + lora_scale = lora_body.get("scale") + if lora_scale is None: + lora_scale = lora_body.get("lora_scale") + lora_int_id = lora_body.get("int_id") + if lora_int_id is None: + lora_int_id = lora_body.get("lora_int_id") + if lora_int_id is None and lora_path: + lora_int_id = stable_lora_int_id(str(lora_path)) + + if not lora_name or not lora_path: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Invalid lora object: both name and path are required.", + ) + + gen_params["lora_request"] = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) + if lora_scale is not None: + gen_params["lora_scale"] = float(lora_scale) + # Parse and add size if provided if request.size: width, height = parse_size(request.size) diff --git a/vllm_omni/entrypoints/openai/protocol/images.py b/vllm_omni/entrypoints/openai/protocol/images.py index cb7c346ac76..093ecaddf93 100644 --- a/vllm_omni/entrypoints/openai/protocol/images.py +++ b/vllm_omni/entrypoints/openai/protocol/images.py @@ -8,6 +8,7 @@ """ from enum import Enum +from typing import Any from pydantic import BaseModel, Field, field_validator @@ -88,6 +89,18 @@ def validate_response_format(cls, v): ) seed: int | None = Field(default=None, description="Random seed for reproducibility") + # vllm-omni extension for per-request LoRA. + # This mirrors the `extra_body.lora` convention in /v1/chat/completions. + lora: dict[str, Any] | None = Field( + default=None, + description=( + "Optional LoRA adapter for this request. Expected shape: " + "{name/path/scale/int_id}. Field names are flexible " + "(e.g. name|lora_name|adapter, path|lora_path|local_path, " + "scale|lora_scale, int_id|lora_int_id)." + ), + ) + # VAE memory optimizations (set at model init, included for completeness) vae_use_slicing: bool | None = Field(default=False, description="Enable VAE slicing") vae_use_tiling: bool | None = Field(default=False, description="Enable VAE tiling") diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 8a62fd16c9f..179a0c29052 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -63,7 +63,6 @@ from vllm.entrypoints.utils import should_include_usage from vllm.inputs.data import PromptType, TokensPrompt from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.tokenizers import TokenizerLike @@ -82,6 +81,8 @@ from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio +from vllm_omni.lora.request import LoRARequest +from vllm_omni.lora.utils import stable_lora_int_id from vllm_omni.outputs import OmniRequestOutput if TYPE_CHECKING: @@ -1878,6 +1879,7 @@ async def _create_diffusion_chat_completion( # Text-to-video parameters (ref: text_to_video.py) num_frames = extra_body.get("num_frames") guidance_scale_2 = extra_body.get("guidance_scale_2") # For video high-noise CFG + lora_body = extra_body.get("lora") logger.info( "Diffusion chat request %s: prompt=%r, ref_images=%d, params=%s", @@ -1921,6 +1923,33 @@ async def _create_diffusion_chat_completion( if guidance_scale_2 is not None: gen_kwargs["guidance_scale_2"] = guidance_scale_2 + # Parse per-request LoRA (works for both AsyncOmniDiffusion and AsyncOmni). + if lora_body and isinstance(lora_body, dict): + try: + lora_name = lora_body.get("name") or lora_body.get("lora_name") or lora_body.get("adapter") + lora_path = ( + lora_body.get("local_path") + or lora_body.get("path") + or lora_body.get("lora_path") + or lora_body.get("lora_local_path") + ) + # using "or" directly here may be buggy if `scale=0` + lora_scale = lora_body.get("scale") + if lora_scale is None: + lora_scale = lora_body.get("lora_scale") + lora_int_id = lora_body.get("int_id") + if lora_int_id is None: + lora_int_id = lora_body.get("lora_int_id") + if lora_int_id is None and lora_path: + lora_int_id = stable_lora_int_id(str(lora_path)) + if lora_name and lora_path: + lora_req = LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)) + gen_kwargs["lora_request"] = lora_req + if lora_scale is not None: + gen_kwargs["lora_scale"] = float(lora_scale) + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + # Add reference image if provided if pil_images: if len(pil_images) == 1: diff --git a/vllm_omni/lora/__init__.py b/vllm_omni/lora/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/lora/request.py b/vllm_omni/lora/request.py new file mode 100644 index 00000000000..55eb02ba447 --- /dev/null +++ b/vllm_omni/lora/request.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# for now, it suffices to use vLLM's implementation directly +# as this is a user-facing variable, defined here to so that user can directly import LoRARequest from vllm_omni +from vllm.lora.request import LoRARequest + +__all__ = ["LoRARequest"] diff --git a/vllm_omni/lora/utils.py b/vllm_omni/lora/utils.py new file mode 100644 index 00000000000..9404d080f6c --- /dev/null +++ b/vllm_omni/lora/utils.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import hashlib + + +def stable_lora_int_id(lora_path: str) -> int: + """Return a deterministic positive integer ID for a LoRA adapter. + + vLLM uses `lora_int_id` as the adapter's cache key. Python's built-in + `hash()` is intentionally randomized per process (PYTHONHASHSEED), which + makes it unsuitable for persistent IDs. This helper derives a stable + 63-bit positive integer from the adapter path. + """ + digest = hashlib.sha256(lora_path.encode("utf-8")).digest() + value = int.from_bytes(digest[:8], byteorder="big", signed=False) & ((1 << 63) - 1) + return value or 1 + + +__all__ = ["stable_lora_int_id"]