diff --git a/backend/python/common/libbackend.sh b/backend/python/common/libbackend.sh index b4bcf578fc73..9af6ca6736f5 100644 --- a/backend/python/common/libbackend.sh +++ b/backend/python/common/libbackend.sh @@ -237,7 +237,14 @@ function getBuildProfile() { # Make the venv relocatable: # - rewrite venv/bin/python{,3} to relative symlinks into $(_portable_dir) # - normalize entrypoint shebangs to /usr/bin/env python3 +# - optionally update pyvenv.cfg to point to the portable Python directory (only at runtime) +# Usage: _makeVenvPortable [--update-pyvenv-cfg] _makeVenvPortable() { + local update_pyvenv_cfg=false + if [ "${1:-}" = "--update-pyvenv-cfg" ]; then + update_pyvenv_cfg=true + fi + local venv_dir="${EDIR}/venv" local vbin="${venv_dir}/bin" @@ -255,7 +262,39 @@ _makeVenvPortable() { ln -s "${rel_py}" "${vbin}/python3" ln -s "python3" "${vbin}/python" - # 2) Rewrite shebangs of entry points to use env, so the venv is relocatable + # 2) Update pyvenv.cfg to point to the portable Python directory (only at runtime) + # Use absolute path resolved at runtime so it works when the venv is copied + if [ "$update_pyvenv_cfg" = "true" ]; then + local pyvenv_cfg="${venv_dir}/pyvenv.cfg" + if [ -f "${pyvenv_cfg}" ]; then + local portable_dir="$(_portable_dir)" + # Resolve to absolute path - this ensures it works when the backend is copied + # Only resolve if the directory exists (it should if ensurePortablePython was called) + if [ -d "${portable_dir}" ]; then + portable_dir="$(cd "${portable_dir}" && pwd)" + else + # Fallback to relative path if directory doesn't exist yet + portable_dir="../python" + fi + local sed_i=(sed -i) + # macOS/BSD sed needs a backup suffix; GNU sed doesn't. Make it portable: + if sed --version >/dev/null 2>&1; then + sed_i=(sed -i) + else + sed_i=(sed -i '') + fi + # Update the home field in pyvenv.cfg + # Handle both absolute paths (starting with /) and relative paths + if grep -q "^home = " "${pyvenv_cfg}"; then + "${sed_i[@]}" "s|^home = .*|home = ${portable_dir}|" "${pyvenv_cfg}" + else + # If home field doesn't exist, add it + echo "home = ${portable_dir}" >> "${pyvenv_cfg}" + fi + fi + fi + + # 3) Rewrite shebangs of entry points to use env, so the venv is relocatable # Only touch text files that start with #! and reference the current venv. local ve_abs="${vbin}/python" local sed_i=(sed -i) @@ -316,6 +355,7 @@ function ensureVenv() { fi fi if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then + # During install, only update symlinks and shebangs, not pyvenv.cfg _makeVenvPortable fi fi @@ -420,6 +460,11 @@ function installRequirements() { # - ${BACKEND_NAME}.py function startBackend() { ensureVenv + # Update pyvenv.cfg before running to ensure paths are correct for current location + # This is critical when the backend position is dynamic (e.g., copied from container) + if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -x "$(_portable_python)" ]; then + _makeVenvPortable --update-pyvenv-cfg + fi if [ ! -z "${BACKEND_FILE:-}" ]; then exec "${EDIR}/venv/bin/python" "${BACKEND_FILE}" "$@" elif [ -e "${MY_DIR}/server.py" ]; then diff --git a/backend/python/diffusers/README.md b/backend/python/diffusers/README.md index f91beef69369..91fff3127694 100644 --- a/backend/python/diffusers/README.md +++ b/backend/python/diffusers/README.md @@ -1,5 +1,136 @@ -# Creating a separate environment for the diffusers project +# LocalAI Diffusers Backend + +This backend provides gRPC access to Hugging Face diffusers pipelines with dynamic pipeline loading. + +## Creating a separate environment for the diffusers project ``` make diffusers -``` \ No newline at end of file +``` + +## Dynamic Pipeline Loader + +The diffusers backend includes a dynamic pipeline loader (`diffusers_dynamic_loader.py`) that automatically discovers and loads diffusers pipelines at runtime. This eliminates the need for per-pipeline conditional statements - new pipelines added to diffusers become available automatically without code changes. + +### How It Works + +1. **Pipeline Discovery**: On first use, the loader scans the `diffusers` package to find all classes that inherit from `DiffusionPipeline`. + +2. **Registry Caching**: Discovery results are cached for the lifetime of the process to avoid repeated scanning. + +3. **Task Aliases**: The loader automatically derives task aliases from class names (e.g., "text-to-image", "image-to-image", "inpainting") without hardcoding. + +4. **Multiple Resolution Methods**: Pipelines can be resolved by: + - Exact class name (e.g., `StableDiffusionPipeline`) + - Task alias (e.g., `text-to-image`, `img2img`) + - Model ID (uses HuggingFace Hub to infer pipeline type) + +### Usage Examples + +```python +from diffusers_dynamic_loader import ( + load_diffusers_pipeline, + get_available_pipelines, + get_available_tasks, + resolve_pipeline_class, + discover_diffusers_classes, + get_available_classes, +) + +# List all available pipelines +pipelines = get_available_pipelines() +print(f"Available pipelines: {pipelines[:10]}...") + +# List all task aliases +tasks = get_available_tasks() +print(f"Available tasks: {tasks}") + +# Resolve a pipeline class by name +cls = resolve_pipeline_class(class_name="StableDiffusionPipeline") + +# Resolve by task alias +cls = resolve_pipeline_class(task="stable-diffusion") + +# Load and instantiate a pipeline +pipe = load_diffusers_pipeline( + class_name="StableDiffusionPipeline", + model_id="runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16 +) + +# Load from single file +pipe = load_diffusers_pipeline( + class_name="StableDiffusionPipeline", + model_id="/path/to/model.safetensors", + from_single_file=True, + torch_dtype=torch.float16 +) + +# Discover other diffusers classes (schedulers, models, etc.) +schedulers = discover_diffusers_classes("SchedulerMixin") +print(f"Available schedulers: {list(schedulers.keys())[:5]}...") + +# Get list of available scheduler classes +scheduler_list = get_available_classes("SchedulerMixin") +``` + +### Generic Class Discovery + +The dynamic loader can discover not just pipelines but any class type from diffusers: + +```python +# Discover all scheduler classes +schedulers = discover_diffusers_classes("SchedulerMixin") + +# Discover all model classes +models = discover_diffusers_classes("ModelMixin") + +# Get a sorted list of available classes +scheduler_names = get_available_classes("SchedulerMixin") +``` + +### Special Pipeline Handling + +Most pipelines are loaded dynamically through `load_diffusers_pipeline()`. Only pipelines requiring truly custom initialization logic are handled explicitly: + +- `FluxTransformer2DModel`: Requires quantization and custom transformer loading (cannot use dynamic loader) +- `WanPipeline` / `WanImageToVideoPipeline`: Uses dynamic loader with special VAE (float32 dtype) +- `SanaPipeline`: Uses dynamic loader with post-load dtype conversion for VAE/text encoder +- `StableVideoDiffusionPipeline`: Uses dynamic loader with CPU offload handling +- `VideoDiffusionPipeline`: Alias for DiffusionPipeline with video flags + +All other pipelines (StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline, etc.) are loaded purely through the dynamic loader. + +### Error Handling + +When a pipeline cannot be resolved, the loader provides helpful error messages listing available pipelines and tasks: + +``` +ValueError: Unknown pipeline class 'NonExistentPipeline'. +Available pipelines: AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline, ... +``` + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `COMPEL` | `0` | Enable Compel for prompt weighting | +| `XPU` | `0` | Enable Intel XPU support | +| `CLIPSKIP` | `1` | Enable CLIP skip support | +| `SAFETENSORS` | `1` | Use safetensors format | +| `CHUNK_SIZE` | `8` | Decode chunk size for video | +| `FPS` | `7` | Video frames per second | +| `DISABLE_CPU_OFFLOAD` | `0` | Disable CPU offload | +| `FRAMES` | `64` | Number of video frames | +| `BFL_REPO` | `ChuckMcSneed/FLUX.1-dev` | Flux base repo | +| `PYTHON_GRPC_MAX_WORKERS` | `1` | Max gRPC workers | + +## Running Tests + +```bash +./test.sh +``` + +The test suite includes: +- Unit tests for the dynamic loader (`test_dynamic_loader.py`) +- Integration tests for the gRPC backend (`test.py`) \ No newline at end of file diff --git a/backend/python/diffusers/backend.py b/backend/python/diffusers/backend.py index 026e3284a72b..cc7ff2288cee 100755 --- a/backend/python/diffusers/backend.py +++ b/backend/python/diffusers/backend.py @@ -1,4 +1,10 @@ #!/usr/bin/env python3 +""" +LocalAI Diffusers Backend + +This backend provides gRPC access to diffusers pipelines with dynamic pipeline loading. +New pipelines added to diffusers become available automatically without code changes. +""" from concurrent import futures import traceback import argparse @@ -17,14 +23,22 @@ import grpc -from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \ - EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline -from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline +# Import dynamic loader for pipeline discovery +from diffusers_dynamic_loader import ( + get_pipeline_registry, + resolve_pipeline_class, + get_available_pipelines, + load_diffusers_pipeline, +) + +# Import specific items still needed for special cases and safety checker +from diffusers import DiffusionPipeline, ControlNetModel +from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKLWan from diffusers.pipelines.stable_diffusion import safety_checker from diffusers.utils import load_image, export_to_video from compel import Compel, ReturnedEmbeddingsType from optimum.quanto import freeze, qfloat8, quantize -from transformers import CLIPTextModel, T5EncoderModel +from transformers import T5EncoderModel from safetensors.torch import load_file _ONE_DAY_IN_SECONDS = 60 * 60 * 24 @@ -158,6 +172,165 @@ def get_scheduler(name: str, config: dict = {}): # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): + + def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant): + """ + Load a diffusers pipeline dynamically using the dynamic loader. + + This method uses load_diffusers_pipeline() for most pipelines, falling back + to explicit handling only for pipelines requiring custom initialization + (e.g., quantization, special VAE handling). + + Args: + request: The gRPC request containing pipeline configuration + modelFile: Path to the model file (for single file loading) + fromSingleFile: Whether to use from_single_file() vs from_pretrained() + torchType: The torch dtype to use + variant: Model variant (e.g., "fp16") + + Returns: + The loaded pipeline instance + """ + pipeline_type = request.PipelineType + + # Handle IMG2IMG request flag with default pipeline + if request.IMG2IMG and pipeline_type == "": + pipeline_type = "StableDiffusionImg2ImgPipeline" + + # ================================================================ + # Special cases requiring custom initialization logic + # Only handle pipelines that truly need custom code (quantization, + # special VAE handling, etc.). All other pipelines use dynamic loading. + # ================================================================ + + # FluxTransformer2DModel - requires quantization and custom transformer loading + if pipeline_type == "FluxTransformer2DModel": + dtype = torch.bfloat16 + bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev") + + transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype) + quantize(transformer, weights=qfloat8) + freeze(transformer) + text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype) + quantize(text_encoder_2, weights=qfloat8) + freeze(text_encoder_2) + + pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype) + pipe.transformer = transformer + pipe.text_encoder_2 = text_encoder_2 + + if request.LowVRAM: + pipe.enable_model_cpu_offload() + return pipe + + # WanPipeline - requires special VAE with float32 dtype + if pipeline_type == "WanPipeline": + vae = AutoencoderKLWan.from_pretrained( + request.Model, + subfolder="vae", + torch_dtype=torch.float32 + ) + pipe = load_diffusers_pipeline( + class_name="WanPipeline", + model_id=request.Model, + vae=vae, + torch_dtype=torchType + ) + self.txt2vid = True + return pipe + + # WanImageToVideoPipeline - requires special VAE with float32 dtype + if pipeline_type == "WanImageToVideoPipeline": + vae = AutoencoderKLWan.from_pretrained( + request.Model, + subfolder="vae", + torch_dtype=torch.float32 + ) + pipe = load_diffusers_pipeline( + class_name="WanImageToVideoPipeline", + model_id=request.Model, + vae=vae, + torch_dtype=torchType + ) + self.img2vid = True + return pipe + + # SanaPipeline - requires special VAE and text encoder dtype conversion + if pipeline_type == "SanaPipeline": + pipe = load_diffusers_pipeline( + class_name="SanaPipeline", + model_id=request.Model, + variant="bf16", + torch_dtype=torch.bfloat16 + ) + pipe.vae.to(torch.bfloat16) + pipe.text_encoder.to(torch.bfloat16) + return pipe + + # VideoDiffusionPipeline - alias for DiffusionPipeline with txt2vid flag + if pipeline_type == "VideoDiffusionPipeline": + self.txt2vid = True + pipe = load_diffusers_pipeline( + class_name="DiffusionPipeline", + model_id=request.Model, + torch_dtype=torchType + ) + return pipe + + # StableVideoDiffusionPipeline - needs img2vid flag and CPU offload + if pipeline_type == "StableVideoDiffusionPipeline": + self.img2vid = True + pipe = load_diffusers_pipeline( + class_name="StableVideoDiffusionPipeline", + model_id=request.Model, + torch_dtype=torchType, + variant=variant + ) + if not DISABLE_CPU_OFFLOAD: + pipe.enable_model_cpu_offload() + return pipe + + # ================================================================ + # Dynamic pipeline loading - the default path for most pipelines + # Uses the dynamic loader to instantiate any pipeline by class name + # ================================================================ + + # Build kwargs for dynamic loading + load_kwargs = {"torch_dtype": torchType} + + # Add variant if not loading from single file + if not fromSingleFile and variant: + load_kwargs["variant"] = variant + + # Add use_safetensors for from_pretrained + if not fromSingleFile: + load_kwargs["use_safetensors"] = SAFETENSORS + + # Determine pipeline class name - default to AutoPipelineForText2Image + effective_pipeline_type = pipeline_type if pipeline_type else "AutoPipelineForText2Image" + + # Use dynamic loader for all pipelines + try: + pipe = load_diffusers_pipeline( + class_name=effective_pipeline_type, + model_id=modelFile if fromSingleFile else request.Model, + from_single_file=fromSingleFile, + **load_kwargs + ) + except Exception as e: + # Provide helpful error with available pipelines + available = get_available_pipelines() + raise ValueError( + f"Failed to load pipeline '{effective_pipeline_type}': {e}\n" + f"Available pipelines: {', '.join(available[:30])}..." + ) from e + + # Apply LowVRAM optimization if supported and requested + if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload'): + pipe.enable_model_cpu_offload() + + return pipe + def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) @@ -231,139 +404,16 @@ def LoadModel(self, request, context): fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local self.img2vid = False self.txt2vid = False - ## img2img - if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""): - if fromSingleFile: - self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile, - torch_dtype=torchType) - else: - self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model, - torch_dtype=torchType) - - elif request.PipelineType == "StableDiffusionDepth2ImgPipeline": - self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model, - torch_dtype=torchType) - ## img2vid - elif request.PipelineType == "StableVideoDiffusionPipeline": - self.img2vid = True - self.pipe = StableVideoDiffusionPipeline.from_pretrained( - request.Model, torch_dtype=torchType, variant=variant - ) - if not DISABLE_CPU_OFFLOAD: - self.pipe.enable_model_cpu_offload() - ## text2img - elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "": - self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model, - torch_dtype=torchType, - use_safetensors=SAFETENSORS, - variant=variant) - elif request.PipelineType == "StableDiffusionPipeline": - if fromSingleFile: - self.pipe = StableDiffusionPipeline.from_single_file(modelFile, - torch_dtype=torchType) - else: - self.pipe = StableDiffusionPipeline.from_pretrained(request.Model, - torch_dtype=torchType) - elif request.PipelineType == "DiffusionPipeline": - self.pipe = DiffusionPipeline.from_pretrained(request.Model, - torch_dtype=torchType) - elif request.PipelineType == "QwenImageEditPipeline": - self.pipe = QwenImageEditPipeline.from_pretrained(request.Model, - torch_dtype=torchType) - elif request.PipelineType == "VideoDiffusionPipeline": - self.txt2vid = True - self.pipe = DiffusionPipeline.from_pretrained(request.Model, - torch_dtype=torchType) - elif request.PipelineType == "StableDiffusionXLPipeline": - if fromSingleFile: - self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile, - torch_dtype=torchType, - use_safetensors=True) - else: - self.pipe = StableDiffusionXLPipeline.from_pretrained( - request.Model, - torch_dtype=torchType, - use_safetensors=True, - variant=variant) - elif request.PipelineType == "StableDiffusion3Pipeline": - if fromSingleFile: - self.pipe = StableDiffusion3Pipeline.from_single_file(modelFile, - torch_dtype=torchType, - use_safetensors=True) - else: - self.pipe = StableDiffusion3Pipeline.from_pretrained( - request.Model, - torch_dtype=torchType, - use_safetensors=True, - variant=variant) - elif request.PipelineType == "FluxPipeline": - if fromSingleFile: - self.pipe = FluxPipeline.from_single_file(modelFile, - torch_dtype=torchType, - use_safetensors=True) - else: - self.pipe = FluxPipeline.from_pretrained( - request.Model, - torch_dtype=torch.bfloat16) - if request.LowVRAM: - self.pipe.enable_model_cpu_offload() - elif request.PipelineType == "FluxTransformer2DModel": - dtype = torch.bfloat16 - # specify from environment or default to "ChuckMcSneed/FLUX.1-dev" - bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev") - - transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype) - quantize(transformer, weights=qfloat8) - freeze(transformer) - text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype) - quantize(text_encoder_2, weights=qfloat8) - freeze(text_encoder_2) - - self.pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype) - self.pipe.transformer = transformer - self.pipe.text_encoder_2 = text_encoder_2 - - if request.LowVRAM: - self.pipe.enable_model_cpu_offload() - elif request.PipelineType == "Lumina2Text2ImgPipeline": - self.pipe = Lumina2Text2ImgPipeline.from_pretrained( - request.Model, - torch_dtype=torch.bfloat16) - if request.LowVRAM: - self.pipe.enable_model_cpu_offload() - elif request.PipelineType == "SanaPipeline": - self.pipe = SanaPipeline.from_pretrained( - request.Model, - variant="bf16", - torch_dtype=torch.bfloat16) - self.pipe.vae.to(torch.bfloat16) - self.pipe.text_encoder.to(torch.bfloat16) - elif request.PipelineType == "WanPipeline": - # WAN2.2 pipeline requires special VAE handling - vae = AutoencoderKLWan.from_pretrained( - request.Model, - subfolder="vae", - torch_dtype=torch.float32 - ) - self.pipe = WanPipeline.from_pretrained( - request.Model, - vae=vae, - torch_dtype=torchType - ) - self.txt2vid = True # WAN2.2 is a text-to-video pipeline - elif request.PipelineType == "WanImageToVideoPipeline": - # WAN2.2 image-to-video pipeline - vae = AutoencoderKLWan.from_pretrained( - request.Model, - subfolder="vae", - torch_dtype=torch.float32 - ) - self.pipe = WanImageToVideoPipeline.from_pretrained( - request.Model, - vae=vae, - torch_dtype=torchType - ) - self.img2vid = True # WAN2.2 image-to-video pipeline + + # Load pipeline using dynamic loader + # Special cases that require custom initialization are handled first + self.pipe = self._load_pipeline( + request=request, + modelFile=modelFile, + fromSingleFile=fromSingleFile, + torchType=torchType, + variant=variant + ) if CLIPSKIP and request.CLIPSkip != 0: self.clip_skip = request.CLIPSkip diff --git a/backend/python/diffusers/diffusers_dynamic_loader.py b/backend/python/diffusers/diffusers_dynamic_loader.py new file mode 100644 index 000000000000..e47c7c2cf08b --- /dev/null +++ b/backend/python/diffusers/diffusers_dynamic_loader.py @@ -0,0 +1,538 @@ +""" +Dynamic Diffusers Pipeline Loader + +This module provides dynamic discovery and loading of diffusers pipelines at runtime, +eliminating the need for per-pipeline conditional statements. New pipelines added to +diffusers become available automatically without code changes. + +The module also supports discovering other diffusers classes like schedulers, models, +and other components, making it a generic solution for dynamic class loading. + +Usage: + from diffusers_dynamic_loader import load_diffusers_pipeline, get_available_pipelines + + # Load by class name + pipe = load_diffusers_pipeline(class_name="StableDiffusionPipeline", model_id="...", torch_dtype=torch.float16) + + # Load by task alias + pipe = load_diffusers_pipeline(task="text-to-image", model_id="...", torch_dtype=torch.float16) + + # Load using model_id (infers from HuggingFace Hub if possible) + pipe = load_diffusers_pipeline(model_id="runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + + # Get list of available pipelines + available = get_available_pipelines() + + # Discover other diffusers classes (schedulers, models, etc.) + schedulers = discover_diffusers_classes("SchedulerMixin") + models = discover_diffusers_classes("ModelMixin") +""" + +import importlib +import re +import sys +from typing import Any, Dict, List, Optional, Tuple, Type + + +# Global cache for discovered pipelines - computed once per process +_pipeline_registry: Optional[Dict[str, Type]] = None +_task_aliases: Optional[Dict[str, List[str]]] = None + +# Global cache for other discovered class types +_class_registries: Dict[str, Dict[str, Type]] = {} + + +def _camel_to_kebab(name: str) -> str: + """ + Convert CamelCase to kebab-case. + + Examples: + StableDiffusionPipeline -> stable-diffusion-pipeline + StableDiffusionXLImg2ImgPipeline -> stable-diffusion-xl-img-2-img-pipeline + """ + # Insert hyphen before uppercase letters (but not at the start) + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1-\2', name) + # Insert hyphen before uppercase letters following lowercase letters or numbers + s2 = re.sub('([a-z0-9])([A-Z])', r'\1-\2', s1) + return s2.lower() + + +def _extract_task_keywords(class_name: str) -> List[str]: + """ + Extract task-related keywords from a pipeline class name. + + This function derives useful task aliases from the class name without + hardcoding per-pipeline branches. + + Returns a list of potential task aliases for this pipeline. + """ + aliases = [] + name_lower = class_name.lower() + + # Direct task mappings based on common patterns in class names + task_patterns = { + 'text2image': ['text-to-image', 'txt2img', 'text2image'], + 'texttoimage': ['text-to-image', 'txt2img', 'text2image'], + 'txt2img': ['text-to-image', 'txt2img', 'text2image'], + 'img2img': ['image-to-image', 'img2img', 'image2image'], + 'image2image': ['image-to-image', 'img2img', 'image2image'], + 'imagetoimage': ['image-to-image', 'img2img', 'image2image'], + 'img2video': ['image-to-video', 'img2vid', 'img2video'], + 'imagetovideo': ['image-to-video', 'img2vid', 'img2video'], + 'text2video': ['text-to-video', 'txt2vid', 'text2video'], + 'texttovideo': ['text-to-video', 'txt2vid', 'text2video'], + 'inpaint': ['inpainting', 'inpaint'], + 'depth2img': ['depth-to-image', 'depth2img'], + 'depthtoimage': ['depth-to-image', 'depth2img'], + 'controlnet': ['controlnet', 'control-net'], + 'upscale': ['upscaling', 'upscale', 'super-resolution'], + 'superresolution': ['upscaling', 'upscale', 'super-resolution'], + } + + # Check for each pattern in the class name + for pattern, task_aliases in task_patterns.items(): + if pattern in name_lower: + aliases.extend(task_aliases) + + # Also detect general pipeline types from the class name structure + # E.g., StableDiffusionPipeline -> stable-diffusion, flux -> flux + # Remove "Pipeline" suffix and convert to kebab case + if class_name.endswith('Pipeline'): + base_name = class_name[:-8] # Remove "Pipeline" + kebab_name = _camel_to_kebab(base_name) + aliases.append(kebab_name) + + # Extract model family name (e.g., "stable-diffusion" from "stable-diffusion-xl-img-2-img") + parts = kebab_name.split('-') + if len(parts) >= 2: + # Try the first two words as a family name + family = '-'.join(parts[:2]) + if family not in aliases: + aliases.append(family) + + # If no specific task pattern matched but class contains "Pipeline", add "text-to-image" as default + # since most diffusion pipelines support text-to-image generation + if 'text-to-image' not in aliases and 'image-to-image' not in aliases: + # Only add for pipelines that seem to be generation pipelines (not schedulers, etc.) + if 'pipeline' in name_lower and not any(x in name_lower for x in ['scheduler', 'processor', 'encoder']): + # Don't automatically add - let it be explicit + pass + + return list(set(aliases)) # Remove duplicates + + +def discover_diffusers_classes( + base_class_name: str, + include_base: bool = True +) -> Dict[str, Type]: + """ + Discover all subclasses of a given base class from diffusers. + + This function provides a generic way to discover any type of diffusers class, + not just pipelines. It can be used to discover schedulers, models, processors, + and other components. + + Args: + base_class_name: Name of the base class to search for subclasses + (e.g., "DiffusionPipeline", "SchedulerMixin", "ModelMixin") + include_base: Whether to include the base class itself in results + + Returns: + Dict mapping class names to class objects + + Examples: + # Discover all pipeline classes + pipelines = discover_diffusers_classes("DiffusionPipeline") + + # Discover all scheduler classes + schedulers = discover_diffusers_classes("SchedulerMixin") + + # Discover all model classes + models = discover_diffusers_classes("ModelMixin") + + # Discover AutoPipeline classes + auto_pipelines = discover_diffusers_classes("AutoPipelineForText2Image") + """ + global _class_registries + + # Check cache first + if base_class_name in _class_registries: + return _class_registries[base_class_name] + + import diffusers + + # Try to get the base class from diffusers + base_class = None + try: + base_class = getattr(diffusers, base_class_name) + except AttributeError: + # Try to find in submodules + for submodule in ['schedulers', 'models', 'pipelines']: + try: + module = importlib.import_module(f'diffusers.{submodule}') + if hasattr(module, base_class_name): + base_class = getattr(module, base_class_name) + break + except (ImportError, ModuleNotFoundError): + continue + + if base_class is None: + raise ValueError(f"Could not find base class '{base_class_name}' in diffusers") + + registry: Dict[str, Type] = {} + + # Include base class if requested + if include_base: + registry[base_class_name] = base_class + + # Scan diffusers module for subclasses + for attr_name in dir(diffusers): + try: + attr = getattr(diffusers, attr_name) + if (isinstance(attr, type) and + issubclass(attr, base_class) and + (include_base or attr is not base_class)): + registry[attr_name] = attr + except (ImportError, AttributeError, TypeError, RuntimeError, ModuleNotFoundError): + continue + + # Cache the results + _class_registries[base_class_name] = registry + return registry + + +def get_available_classes(base_class_name: str) -> List[str]: + """ + Get a sorted list of all discovered class names for a given base class. + + Args: + base_class_name: Name of the base class (e.g., "SchedulerMixin") + + Returns: + Sorted list of discovered class names + """ + return sorted(discover_diffusers_classes(base_class_name).keys()) + + +def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]: + """ + Discover all subclasses of DiffusionPipeline from diffusers. + + This function uses the generic discover_diffusers_classes() internally + and adds pipeline-specific task alias generation. It also includes + AutoPipeline classes which are special utility classes for automatic + pipeline selection. + + Returns: + A tuple of (pipeline_registry, task_aliases) where: + - pipeline_registry: Dict mapping class names to class objects + - task_aliases: Dict mapping task aliases to lists of class names + """ + # Use the generic discovery function + pipeline_registry = discover_diffusers_classes("DiffusionPipeline", include_base=True) + + # Also add AutoPipeline classes - these are special utility classes that are + # NOT subclasses of DiffusionPipeline but are commonly used + import diffusers + auto_pipeline_classes = [ + "AutoPipelineForText2Image", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + ] + for cls_name in auto_pipeline_classes: + try: + cls = getattr(diffusers, cls_name) + if cls is not None: + pipeline_registry[cls_name] = cls + except AttributeError: + # Class not available in this version of diffusers + pass + + # Generate task aliases for pipelines + task_aliases: Dict[str, List[str]] = {} + for attr_name in pipeline_registry: + if attr_name == "DiffusionPipeline": + continue # Skip base class for alias generation + + aliases = _extract_task_keywords(attr_name) + for alias in aliases: + if alias not in task_aliases: + task_aliases[alias] = [] + if attr_name not in task_aliases[alias]: + task_aliases[alias].append(attr_name) + + return pipeline_registry, task_aliases + + +def get_pipeline_registry() -> Dict[str, Type]: + """ + Get the cached pipeline registry. + + Returns a dictionary mapping pipeline class names to their class objects. + The registry is built on first access and cached for subsequent calls. + """ + global _pipeline_registry, _task_aliases + if _pipeline_registry is None: + _pipeline_registry, _task_aliases = _discover_pipelines() + return _pipeline_registry + + +def get_task_aliases() -> Dict[str, List[str]]: + """ + Get the cached task aliases dictionary. + + Returns a dictionary mapping task aliases (e.g., "text-to-image") to + lists of pipeline class names that support that task. + """ + global _pipeline_registry, _task_aliases + if _task_aliases is None: + _pipeline_registry, _task_aliases = _discover_pipelines() + return _task_aliases + + +def get_available_pipelines() -> List[str]: + """ + Get a sorted list of all discovered pipeline class names. + + Returns: + List of pipeline class names available for loading. + """ + return sorted(get_pipeline_registry().keys()) + + +def get_available_tasks() -> List[str]: + """ + Get a sorted list of all available task aliases. + + Returns: + List of task aliases (e.g., ["text-to-image", "image-to-image", ...]) + """ + return sorted(get_task_aliases().keys()) + + +def resolve_pipeline_class( + class_name: Optional[str] = None, + task: Optional[str] = None, + model_id: Optional[str] = None +) -> Type: + """ + Resolve a pipeline class from class_name, task, or model_id. + + Priority: + 1. If class_name is provided, look it up directly + 2. If task is provided, resolve through task aliases + 3. If model_id is provided, try to infer from HuggingFace Hub + + Args: + class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline") + task: Task alias (e.g., "text-to-image", "img2img") + model_id: HuggingFace model ID (e.g., "runwayml/stable-diffusion-v1-5") + + Returns: + The resolved pipeline class. + + Raises: + ValueError: If no pipeline could be resolved. + """ + registry = get_pipeline_registry() + aliases = get_task_aliases() + + # 1. Direct class name lookup + if class_name: + if class_name in registry: + return registry[class_name] + # Try case-insensitive match + for name, cls in registry.items(): + if name.lower() == class_name.lower(): + return cls + raise ValueError( + f"Unknown pipeline class '{class_name}'. " + f"Available pipelines: {', '.join(sorted(registry.keys())[:20])}..." + ) + + # 2. Task alias lookup + if task: + task_lower = task.lower().replace('_', '-') + if task_lower in aliases: + # Return the first matching pipeline for this task + matching_classes = aliases[task_lower] + if matching_classes: + return registry[matching_classes[0]] + + # Try partial matching + for alias, classes in aliases.items(): + if task_lower in alias or alias in task_lower: + if classes: + return registry[classes[0]] + + raise ValueError( + f"Unknown task '{task}'. " + f"Available tasks: {', '.join(sorted(aliases.keys())[:20])}..." + ) + + # 3. Try to infer from HuggingFace Hub + if model_id: + try: + from huggingface_hub import model_info + info = model_info(model_id) + + # Check pipeline_tag + if hasattr(info, 'pipeline_tag') and info.pipeline_tag: + tag = info.pipeline_tag.lower().replace('_', '-') + if tag in aliases: + matching_classes = aliases[tag] + if matching_classes: + return registry[matching_classes[0]] + + # Check model card for hints + if hasattr(info, 'cardData') and info.cardData: + card = info.cardData + if 'pipeline_tag' in card: + tag = card['pipeline_tag'].lower().replace('_', '-') + if tag in aliases: + matching_classes = aliases[tag] + if matching_classes: + return registry[matching_classes[0]] + + except ImportError: + # huggingface_hub not available + pass + except (KeyError, AttributeError, ValueError, OSError): + # Model info lookup failed - common cases: + # - KeyError: Missing keys in model card + # - AttributeError: Missing attributes on model info + # - ValueError: Invalid model data + # - OSError: Network or file access issues + pass + + # Fallback: use DiffusionPipeline.from_pretrained which auto-detects + # DiffusionPipeline is always added to registry in _discover_pipelines (line 132) + # but use .get() with import fallback for extra safety + from diffusers import DiffusionPipeline + return registry.get('DiffusionPipeline', DiffusionPipeline) + + raise ValueError( + "Must provide at least one of: class_name, task, or model_id. " + f"Available pipelines: {', '.join(sorted(registry.keys())[:20])}... " + f"Available tasks: {', '.join(sorted(aliases.keys())[:20])}..." + ) + + +def load_diffusers_pipeline( + class_name: Optional[str] = None, + task: Optional[str] = None, + model_id: Optional[str] = None, + from_single_file: bool = False, + **kwargs +) -> Any: + """ + Load a diffusers pipeline dynamically. + + This function resolves the appropriate pipeline class based on the provided + parameters and instantiates it with the given kwargs. + + Args: + class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline") + task: Task alias (e.g., "text-to-image", "img2img") + model_id: HuggingFace model ID or local path + from_single_file: If True, use from_single_file() instead of from_pretrained() + **kwargs: Additional arguments passed to from_pretrained() or from_single_file() + + Returns: + An instantiated pipeline object. + + Raises: + ValueError: If no pipeline could be resolved. + Exception: If pipeline loading fails. + + Examples: + # Load by class name + pipe = load_diffusers_pipeline( + class_name="StableDiffusionPipeline", + model_id="runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16 + ) + + # Load by task + pipe = load_diffusers_pipeline( + task="text-to-image", + model_id="runwayml/stable-diffusion-v1-5", + torch_dtype=torch.float16 + ) + + # Load from single file + pipe = load_diffusers_pipeline( + class_name="StableDiffusionPipeline", + model_id="/path/to/model.safetensors", + from_single_file=True, + torch_dtype=torch.float16 + ) + """ + # Resolve the pipeline class + pipeline_class = resolve_pipeline_class( + class_name=class_name, + task=task, + model_id=model_id + ) + + # If no model_id provided but we have a class, we can't load + if model_id is None: + raise ValueError("model_id is required to load a pipeline") + + # Load the pipeline + try: + if from_single_file: + # Check if the class has from_single_file method + if hasattr(pipeline_class, 'from_single_file'): + return pipeline_class.from_single_file(model_id, **kwargs) + else: + raise ValueError( + f"Pipeline class {pipeline_class.__name__} does not support from_single_file(). " + f"Use from_pretrained() instead." + ) + else: + return pipeline_class.from_pretrained(model_id, **kwargs) + + except Exception as e: + # Provide helpful error message + available = get_available_pipelines() + raise RuntimeError( + f"Failed to load pipeline '{pipeline_class.__name__}' from '{model_id}': {e}\n" + f"Available pipelines: {', '.join(available[:20])}..." + ) from e + + +def get_pipeline_info(class_name: str) -> Dict[str, Any]: + """ + Get information about a specific pipeline class. + + Args: + class_name: The pipeline class name + + Returns: + Dictionary with pipeline information including: + - name: Class name + - aliases: List of task aliases + - supports_single_file: Whether from_single_file() is available + - docstring: Class docstring (if available) + """ + registry = get_pipeline_registry() + aliases = get_task_aliases() + + if class_name not in registry: + raise ValueError(f"Unknown pipeline: {class_name}") + + cls = registry[class_name] + + # Find all aliases for this pipeline + pipeline_aliases = [] + for alias, classes in aliases.items(): + if class_name in classes: + pipeline_aliases.append(alias) + + return { + 'name': class_name, + 'aliases': pipeline_aliases, + 'supports_single_file': hasattr(cls, 'from_single_file'), + 'docstring': cls.__doc__[:200] if cls.__doc__ else None + } diff --git a/backend/python/diffusers/test.py b/backend/python/diffusers/test.py index 14b2e175e2c0..5befeca0a99a 100644 --- a/backend/python/diffusers/test.py +++ b/backend/python/diffusers/test.py @@ -1,15 +1,26 @@ """ -A test script to test the gRPC service +A test script to test the gRPC service and dynamic loader """ import unittest import subprocess import time -import backend_pb2 -import backend_pb2_grpc +from unittest.mock import patch, MagicMock -import grpc +# Import dynamic loader for testing (these don't need gRPC) +import diffusers_dynamic_loader as loader +from diffusers import DiffusionPipeline, StableDiffusionPipeline +# Try to import gRPC modules - may not be available during unit testing +try: + import grpc + import backend_pb2 + import backend_pb2_grpc + GRPC_AVAILABLE = True +except ImportError: + GRPC_AVAILABLE = False + +@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available") class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service @@ -82,3 +93,222 @@ def test(self): self.fail("Image gen service failed") finally: self.tearDown() + + +class TestDiffusersDynamicLoader(unittest.TestCase): + """Test cases for the diffusers dynamic loader functionality.""" + + @classmethod + def setUpClass(cls): + """Set up test fixtures - clear caches to ensure fresh discovery.""" + # Reset the caches to ensure fresh discovery + loader._pipeline_registry = None + loader._task_aliases = None + + def test_camel_to_kebab_conversion(self): + """Test CamelCase to kebab-case conversion.""" + test_cases = [ + ("StableDiffusionPipeline", "stable-diffusion-pipeline"), + ("StableDiffusionXLPipeline", "stable-diffusion-xl-pipeline"), + ("FluxPipeline", "flux-pipeline"), + ("DiffusionPipeline", "diffusion-pipeline"), + ] + for input_val, expected in test_cases: + with self.subTest(input=input_val): + result = loader._camel_to_kebab(input_val) + self.assertEqual(result, expected) + + def test_extract_task_keywords(self): + """Test task keyword extraction from class names.""" + # Test text-to-image detection + aliases = loader._extract_task_keywords("StableDiffusionPipeline") + self.assertIn("stable-diffusion", aliases) + + # Test img2img detection + aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline") + self.assertIn("image-to-image", aliases) + self.assertIn("img2img", aliases) + + # Test inpainting detection + aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline") + self.assertIn("inpainting", aliases) + self.assertIn("inpaint", aliases) + + # Test depth2img detection + aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline") + self.assertIn("depth-to-image", aliases) + + def test_discover_pipelines_finds_known_classes(self): + """Test that pipeline discovery finds at least one known pipeline class.""" + registry = loader.get_pipeline_registry() + + # Check that the registry is not empty + self.assertGreater(len(registry), 0, "Pipeline registry should not be empty") + + # Check for known pipeline classes + known_pipelines = [ + "StableDiffusionPipeline", + "DiffusionPipeline", + ] + + for pipeline_name in known_pipelines: + with self.subTest(pipeline=pipeline_name): + self.assertIn( + pipeline_name, + registry, + f"Expected to find {pipeline_name} in registry" + ) + + def test_discover_pipelines_caches_results(self): + """Test that pipeline discovery results are cached.""" + # Get registry twice + registry1 = loader.get_pipeline_registry() + registry2 = loader.get_pipeline_registry() + + # Should be the same object (cached) + self.assertIs(registry1, registry2, "Registry should be cached") + + def test_get_available_pipelines(self): + """Test getting list of available pipelines.""" + available = loader.get_available_pipelines() + + # Should return a list + self.assertIsInstance(available, list) + + # Should contain known pipelines + self.assertIn("StableDiffusionPipeline", available) + self.assertIn("DiffusionPipeline", available) + + # Should be sorted + self.assertEqual(available, sorted(available)) + + def test_get_available_tasks(self): + """Test getting list of available task aliases.""" + tasks = loader.get_available_tasks() + + # Should return a list + self.assertIsInstance(tasks, list) + + # Should be sorted + self.assertEqual(tasks, sorted(tasks)) + + def test_resolve_pipeline_class_by_name(self): + """Test resolving pipeline class by exact name.""" + cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") + self.assertEqual(cls, StableDiffusionPipeline) + + def test_resolve_pipeline_class_by_name_case_insensitive(self): + """Test that class name resolution is case-insensitive.""" + cls1 = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") + cls2 = loader.resolve_pipeline_class(class_name="stablediffusionpipeline") + self.assertEqual(cls1, cls2) + + def test_resolve_pipeline_class_by_task(self): + """Test resolving pipeline class by task alias.""" + # Get the registry to find available tasks + aliases = loader.get_task_aliases() + + # Test with a common task that should be available + if "stable-diffusion" in aliases: + cls = loader.resolve_pipeline_class(task="stable-diffusion") + self.assertIsNotNone(cls) + + def test_resolve_pipeline_class_unknown_name_raises(self): + """Test that resolving unknown class name raises ValueError with helpful message.""" + with self.assertRaises(ValueError) as ctx: + loader.resolve_pipeline_class(class_name="NonExistentPipeline") + + # Check that error message includes available pipelines + error_msg = str(ctx.exception) + self.assertIn("Unknown pipeline class", error_msg) + self.assertIn("Available pipelines", error_msg) + + def test_resolve_pipeline_class_unknown_task_raises(self): + """Test that resolving unknown task raises ValueError with helpful message.""" + with self.assertRaises(ValueError) as ctx: + loader.resolve_pipeline_class(task="nonexistent-task-xyz") + + # Check that error message includes available tasks + error_msg = str(ctx.exception) + self.assertIn("Unknown task", error_msg) + self.assertIn("Available tasks", error_msg) + + def test_resolve_pipeline_class_no_params_raises(self): + """Test that calling with no parameters raises helpful ValueError.""" + with self.assertRaises(ValueError) as ctx: + loader.resolve_pipeline_class() + + error_msg = str(ctx.exception) + self.assertIn("Must provide at least one of", error_msg) + + def test_get_pipeline_info(self): + """Test getting pipeline information.""" + info = loader.get_pipeline_info("StableDiffusionPipeline") + + self.assertEqual(info['name'], "StableDiffusionPipeline") + self.assertIsInstance(info['aliases'], list) + self.assertIsInstance(info['supports_single_file'], bool) + + def test_get_pipeline_info_unknown_raises(self): + """Test that getting info for unknown pipeline raises ValueError.""" + with self.assertRaises(ValueError) as ctx: + loader.get_pipeline_info("NonExistentPipeline") + + self.assertIn("Unknown pipeline", str(ctx.exception)) + + def test_discover_diffusers_classes_pipelines(self): + """Test generic class discovery for DiffusionPipeline.""" + classes = loader.discover_diffusers_classes("DiffusionPipeline") + + # Should return a dict + self.assertIsInstance(classes, dict) + + # Should contain known pipeline classes + self.assertIn("DiffusionPipeline", classes) + self.assertIn("StableDiffusionPipeline", classes) + + def test_discover_diffusers_classes_caches_results(self): + """Test that class discovery results are cached.""" + classes1 = loader.discover_diffusers_classes("DiffusionPipeline") + classes2 = loader.discover_diffusers_classes("DiffusionPipeline") + + # Should be the same object (cached) + self.assertIs(classes1, classes2) + + def test_discover_diffusers_classes_exclude_base(self): + """Test discovering classes without base class.""" + classes = loader.discover_diffusers_classes("DiffusionPipeline", include_base=False) + + # Should still contain subclasses + self.assertIn("StableDiffusionPipeline", classes) + + def test_get_available_classes(self): + """Test getting list of available classes for a base class.""" + classes = loader.get_available_classes("DiffusionPipeline") + + # Should return a sorted list + self.assertIsInstance(classes, list) + self.assertEqual(classes, sorted(classes)) + + # Should contain known classes + self.assertIn("StableDiffusionPipeline", classes) + + +class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase): + """Test cases using mocks to test edge cases.""" + + def test_load_pipeline_requires_model_id(self): + """Test that load_diffusers_pipeline requires model_id.""" + with self.assertRaises(ValueError) as ctx: + loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline") + + self.assertIn("model_id is required", str(ctx.exception)) + + def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self): + """Test that resolving with only model_id falls back to DiffusionPipeline.""" + # When model_id is provided, if hub lookup is not successful, + # should fall back to DiffusionPipeline. + # This tests the fallback behavior - the actual hub lookup may succeed + # or fail depending on network, but the fallback path should work. + cls = loader.resolve_pipeline_class(model_id="some/nonexistent/model") + self.assertEqual(cls, DiffusionPipeline)