From b2b660f76f9d869a2d1e0ebcfd732b646e6241d9 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 14:26:07 +0000 Subject: [PATCH 01/92] Studio: add local diffusion image generation page Backend - core/inference/diffusion.py: DiffusionBackend singleton that loads diffusion GGUFs from Hugging Face via diffusers.GGUFQuantizationConfig and runs them on the active CUDA / MPS / CPU device. Supports FLUX.2, FLUX.2 klein, FLUX.1, Qwen-Image, Stable Diffusion 3, and SDXL. - routes/inference.py: POST /api/inference/images/load, POST /api/inference/images/generate, POST /api/inference/images/unload, GET /api/inference/images/status mirroring the llama-server lifecycle. - models/inference.py: DiffusionLoadRequest, DiffusionGenerateRequest, DiffusionGenerateResponse pydantic schemas with prompt / step / size validation up front so callers get clear 422s rather than VAE crashes. - requirements/no-torch-runtime.txt: pin gguf alongside the existing diffusers entry so GGUFQuantizationConfig works out of the box. - tests/test_diffusion_backend.py + tests/test_diffusion_routes.py: 27 unit tests covering family detection, validation, lifecycle, and the full FastAPI round trip with the backend stubbed. No torch / diffusers / GPU required to run. Frontend - features/images/: standalone images-page.tsx with curated model picker (FLUX.2 klein 4B / 9B, FLUX.2 dev, FLUX.1 dev), HF token field, prompt + negative prompt, resolution presets, steps + guidance sliders, seed input, and a result gallery that renders base64 PNGs inline. - app/routes/images.tsx: lazy /images route wired into router.tsx. - components/app-sidebar.tsx: PaintBrush02Icon nav item between Recipes and Export, hidden in chat-only mode. --- studio/backend/core/inference/diffusion.py | 480 ++++++++++++++++++ studio/backend/models/inference.py | 67 +++ .../backend/requirements/no-torch-runtime.txt | 3 + studio/backend/routes/inference.py | 127 +++++ .../backend/tests/test_diffusion_backend.py | 396 +++++++++++++++ studio/backend/tests/test_diffusion_routes.py | 190 +++++++ studio/frontend/src/app/router.tsx | 2 + studio/frontend/src/app/routes/images.tsx | 21 + .../frontend/src/components/app-sidebar.tsx | 13 + studio/frontend/src/features/images/api.ts | 105 ++++ .../src/features/images/images-page.tsx | 425 ++++++++++++++++ studio/frontend/src/features/images/index.ts | 5 + 12 files changed, 1834 insertions(+) create mode 100644 studio/backend/core/inference/diffusion.py create mode 100644 studio/backend/tests/test_diffusion_backend.py create mode 100644 studio/backend/tests/test_diffusion_routes.py create mode 100644 studio/frontend/src/app/routes/images.tsx create mode 100644 studio/frontend/src/features/images/api.ts create mode 100644 studio/frontend/src/features/images/images-page.tsx create mode 100644 studio/frontend/src/features/images/index.ts diff --git a/studio/backend/core/inference/diffusion.py b/studio/backend/core/inference/diffusion.py new file mode 100644 index 0000000000..c44f132142 --- /dev/null +++ b/studio/backend/core/inference/diffusion.py @@ -0,0 +1,480 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +"""Diffusion image generation backend. + +Loads Hugging Face diffusion checkpoints in either the standard +``diffusers`` layout or the single-file GGUF layout published under +``unsloth/*-GGUF`` (Flux 2, Flux 2 Klein, Qwen-Image, SD3, SDXL, ...). +GGUF files are dynamically dequantised on-device via +``diffusers.GGUFQuantizationConfig``, then the rest of the pipeline +(VAE, text encoders, scheduler) is pulled from the matching ``diffusers`` +repo so end users only ever need one local file plus the metadata repo. + +The module is intentionally torch-only: it never spawns a subprocess and +shares the active CUDA / MPS device with the rest of Studio. The cost of +not having a separate process is that loading a diffusion model and a +GGUF chat model at the same time can OOM on consumer GPUs; the routes +layer must therefore swap between the two as needed (the orchestrator +unloads llama-server before any diffusion load on hosts with < 24 GB). + +The class deliberately exposes a small, llama-cpp-style surface: + + load_model(repo_id, ...) + generate_image(prompt, ...) -> PIL.Image + unload_model() + status() -> dict + +so the route layer at ``studio/backend/routes/inference.py`` can mirror +the existing llama-server lifecycle (probe + load + generate + unload) +without learning a second API. +""" + +from __future__ import annotations + +import asyncio +import gc +import io +import threading +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +from loggers import get_logger + +logger = get_logger(__name__) + + +# ─── Pipeline registry ──────────────────────────────────────────────── +# +# Keep this list narrow on purpose: only ship the small text-to-image +# families with first-class GGUF coverage on the Hub. Anything else is +# either video (LTX*, Wan) or research-grade (Sana, SD3.5) and can be +# added once it has a working GGUF release plus a smoke test. +# +# Each entry maps a substring of the loaded repo id (case-insensitive) +# to the (pipeline_class_name, transformer_class_name, default base +# repo for missing pieces). ``base_repo`` is what we pass to +# ``Pipeline.from_pretrained`` to pick up the VAE + text encoders when +# the user gave us a GGUF-only repo. The base_repo is documented to the +# user via ``status()`` so they understand why a second download fires. + +@dataclass(frozen = True) +class DiffusionFamily: + name: str + pipeline_class: str + transformer_class: str + base_repo: str + # Optional: list of HF "trigger" substrings besides ``name`` that map + # to this family (e.g. "flux1-dev" plus "flux.1-dev"). Lowercased. + aliases: tuple[str, ...] = field(default_factory = tuple) + + +_FAMILIES: tuple[DiffusionFamily, ...] = ( + DiffusionFamily( + name = "flux.2-klein", + pipeline_class = "Flux2KleinPipeline", + transformer_class = "Flux2Transformer2DModel", + base_repo = "black-forest-labs/FLUX.2-klein", + aliases = ("flux2-klein", "flux-2-klein", "flux.2.klein"), + ), + DiffusionFamily( + name = "flux.2", + pipeline_class = "Flux2Pipeline", + transformer_class = "Flux2Transformer2DModel", + base_repo = "black-forest-labs/FLUX.2-dev", + aliases = ("flux2-dev", "flux-2-dev", "flux.2.dev"), + ), + DiffusionFamily( + name = "flux.1", + pipeline_class = "FluxPipeline", + transformer_class = "FluxTransformer2DModel", + base_repo = "black-forest-labs/FLUX.1-dev", + aliases = ("flux1-dev", "flux-1-dev", "flux.1.dev", "flux-dev"), + ), + DiffusionFamily( + name = "qwen-image", + pipeline_class = "QwenImagePipeline", + transformer_class = "QwenImageTransformer2DModel", + base_repo = "Qwen/Qwen-Image", + aliases = ("qwenimage", "qwen_image"), + ), + DiffusionFamily( + name = "stable-diffusion-3", + pipeline_class = "StableDiffusion3Pipeline", + transformer_class = "SD3Transformer2DModel", + base_repo = "stabilityai/stable-diffusion-3-medium-diffusers", + aliases = ("sd3-medium", "stable-diffusion-3-medium", "sd3.5"), + ), + DiffusionFamily( + name = "stable-diffusion-xl", + pipeline_class = "StableDiffusionXLPipeline", + transformer_class = "", # SDXL uses a UNet, not a transformer + base_repo = "stabilityai/stable-diffusion-xl-base-1.0", + aliases = ("sdxl",), + ), +) + + +def detect_family(repo_id: str, *, override_family: Optional[str] = None) -> Optional[DiffusionFamily]: + """Return the diffusion family matching ``repo_id``. + + Matching is substring-based and case-insensitive. ``override_family`` + bypasses substring matching and looks up by ``DiffusionFamily.name``. + Returns ``None`` when no family applies so callers can surface a clear + "unsupported model" error rather than guessing wrong. + """ + if override_family: + wanted = override_family.strip().lower() + for fam in _FAMILIES: + if fam.name == wanted: + return fam + return None + needle = (repo_id or "").lower() + if not needle: + return None + for fam in _FAMILIES: + if fam.name in needle: + return fam + for alias in fam.aliases: + if alias and alias in needle: + return fam + return None + + +def supported_families() -> list[dict[str, str]]: + """Public-facing list of families for ``/api/inference/images/status``.""" + return [ + { + "name": fam.name, + "pipeline_class": fam.pipeline_class, + "base_repo": fam.base_repo, + } + for fam in _FAMILIES + ] + + +# ─── Backend ────────────────────────────────────────────────────────── + + +class DiffusionBackend: + """Singleton-style diffusion backend. + + One pipeline at a time; ``load_model`` swaps the previous one out. + Generation is mutex'd so concurrent requests serialise rather than + racing GPU memory. + """ + + def __init__(self) -> None: + self._pipe: Any = None + self._lock = threading.Lock() + self._family: Optional[DiffusionFamily] = None + self._repo_id: Optional[str] = None + self._gguf_path: Optional[str] = None + self._base_repo: Optional[str] = None + self._device: Optional[str] = None + self._dtype: Optional[str] = None + self._loaded_at: Optional[float] = None + self._loading: bool = False + self._last_error: Optional[str] = None + + # ── lifecycle ───────────────────────────────────────────────── + + @property + def is_loaded(self) -> bool: + return self._pipe is not None + + @property + def repo_id(self) -> Optional[str]: + return self._repo_id + + def status(self) -> dict[str, Any]: + return { + "is_loaded": self.is_loaded, + "is_loading": self._loading, + "repo_id": self._repo_id, + "family": self._family.name if self._family else None, + "pipeline_class": self._family.pipeline_class if self._family else None, + "base_repo": self._base_repo, + "gguf_path": self._gguf_path, + "device": self._device, + "dtype": self._dtype, + "loaded_at": self._loaded_at, + "last_error": self._last_error, + "supported_families": supported_families(), + } + + def _pick_device_and_dtype(self) -> tuple[str, "Any"]: + """Pick (device, dtype) for the current host. + + CUDA-first because that is the only path our diffusion GGUFs are + validated on. On macOS we use MPS in float16 to keep the pipeline + on the Metal GPU. CPU is allowed only as a last resort because + running FLUX on CPU is unusably slow (> 10 minutes per image). + """ + import torch + + if torch.cuda.is_available(): + return "cuda", torch.bfloat16 + if hasattr(torch, "backends") and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + return "mps", torch.float16 + return "cpu", torch.float32 + + def load_model( + self, + repo_id: str, + *, + gguf_filename: Optional[str] = None, + base_repo: Optional[str] = None, + hf_token: Optional[str] = None, + family_override: Optional[str] = None, + enable_model_cpu_offload: bool = True, + ) -> dict[str, Any]: + """Load a diffusion model. + + ``repo_id`` is the Hugging Face repo id of either a GGUF-only + repo (e.g. ``unsloth/FLUX.2-klein-4B-GGUF``) or a full diffusers + repo (e.g. ``black-forest-labs/FLUX.2-klein``). When the repo + contains a GGUF, ``gguf_filename`` picks which quant to load; + otherwise diffusers' standard config-driven load runs. + + ``base_repo`` overrides the auto-detected diffusers base used + for VAE / text encoders. ``family_override`` short-circuits the + substring matcher when an exotic repo name confuses it. + + Raises ``RuntimeError`` on failure with a user-facing message; + the previous pipeline (if any) stays loaded so a failed swap + does not leave Studio in an unusable state. + """ + from huggingface_hub import hf_hub_download + import diffusers + import torch + + fam = detect_family(repo_id, override_family = family_override) + if fam is None: + raise RuntimeError( + f"Could not infer a diffusion family for '{repo_id}'. " + "Pass family_override = 'flux.2-klein' / 'flux.2' / " + "'flux.1' / 'qwen-image' / 'stable-diffusion-3' / " + "'stable-diffusion-xl' to disambiguate." + ) + + device, dtype = self._pick_device_and_dtype() + + with self._lock: + self._loading = True + self._last_error = None + try: + pipeline_cls = getattr(diffusers, fam.pipeline_class, None) + if pipeline_cls is None: + raise RuntimeError( + f"diffusers {diffusers.__version__} has no " + f"{fam.pipeline_class}; upgrade diffusers and retry." + ) + transformer_cls = ( + getattr(diffusers, fam.transformer_class, None) + if fam.transformer_class + else None + ) + + effective_base = base_repo or fam.base_repo + logger.info( + "Loading diffusion model %s (family=%s, device=%s, dtype=%s, base=%s)", + repo_id, + fam.name, + device, + dtype, + effective_base, + ) + + transformer = None + local_gguf_path: Optional[str] = None + if gguf_filename: + if transformer_cls is None: + raise RuntimeError( + f"Family {fam.name} does not have a GGUF transformer " + "path; load the full repo instead." + ) + local_gguf_path = hf_hub_download( + repo_id = repo_id, + filename = gguf_filename, + token = hf_token, + ) + quant_config = diffusers.GGUFQuantizationConfig(compute_dtype = dtype) + transformer = transformer_cls.from_single_file( + local_gguf_path, + quantization_config = quant_config, + torch_dtype = dtype, + ) + + pipe_kwargs: dict[str, Any] = {"torch_dtype": dtype} + if transformer is not None: + pipe_kwargs["transformer"] = transformer + if hf_token: + pipe_kwargs["token"] = hf_token + + pipe = pipeline_cls.from_pretrained(effective_base, **pipe_kwargs) + if enable_model_cpu_offload and device == "cuda": + pipe.enable_model_cpu_offload() + else: + pipe.to(device) + + # Drop the old pipeline only after the new one is in place. + old = self._pipe + with self._lock: + self._pipe = pipe + self._family = fam + self._repo_id = repo_id + self._gguf_path = local_gguf_path + self._base_repo = effective_base + self._device = device + self._dtype = str(dtype).replace("torch.", "") + self._loaded_at = time.time() + _release(old) + + return self.status() + except Exception as exc: + with self._lock: + self._last_error = str(exc) + logger.exception("Diffusion load failed for %s", repo_id) + raise RuntimeError(f"Failed to load diffusion model: {exc}") from exc + finally: + with self._lock: + self._loading = False + + def unload_model(self) -> dict[str, Any]: + with self._lock: + old = self._pipe + self._pipe = None + self._family = None + self._repo_id = None + self._gguf_path = None + self._base_repo = None + self._device = None + self._dtype = None + self._loaded_at = None + _release(old) + return {"is_loaded": False} + + # ── generation ──────────────────────────────────────────────── + + def generate_image( + self, + *, + prompt: str, + negative_prompt: Optional[str] = None, + num_inference_steps: int = 24, + guidance_scale: float = 3.5, + width: int = 1024, + height: int = 1024, + seed: Optional[int] = None, + ) -> "Any": + """Generate a single PIL image and return it. + + The mutex is held for the entire call: diffusion pipelines are + not thread-safe, and overlapping ``__call__``s on a shared + pipeline frequently corrupt their internal scheduler state. + """ + if not prompt or not prompt.strip(): + raise ValueError("prompt is empty") + if num_inference_steps < 1 or num_inference_steps > 200: + raise ValueError("num_inference_steps must be in [1, 200]") + if width <= 0 or height <= 0 or width > 2048 or height > 2048: + raise ValueError("width and height must be in (0, 2048]") + # Snap to a multiple of 8: Flux / SD pipelines require it and a + # silent crash deep in the VAE is much worse than a clear error + # message up front. + if width % 8 or height % 8: + raise ValueError("width and height must be multiples of 8") + + import torch + + with self._lock: + if self._pipe is None: + raise RuntimeError("No diffusion model is loaded.") + pipe = self._pipe + device = self._device or "cpu" + + generator = None + if seed is not None: + # Match the device of the pipeline so determinism holds + # across reload cycles. For CPU offload, the noise still + # has to live on the device the diffusion forward runs on. + gen_device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu" + generator = torch.Generator(device = gen_device).manual_seed(int(seed)) + + call_kwargs: dict[str, Any] = { + "prompt": prompt, + "num_inference_steps": int(num_inference_steps), + "guidance_scale": float(guidance_scale), + "width": int(width), + "height": int(height), + } + if negative_prompt is not None and negative_prompt.strip(): + call_kwargs["negative_prompt"] = negative_prompt + if generator is not None: + call_kwargs["generator"] = generator + + out = pipe(**call_kwargs) + images = getattr(out, "images", None) or [] + if not images: + raise RuntimeError("Diffusion pipeline returned no images.") + return images[0] + + +def encode_png_base64(pil_image: "Any") -> str: + """Encode a PIL image to base64-encoded PNG.""" + import base64 + + buf = io.BytesIO() + pil_image.save(buf, format = "PNG", optimize = True) + return base64.b64encode(buf.getvalue()).decode("ascii") + + +# ─── Helpers ────────────────────────────────────────────────────────── + + +def _release(obj: Any) -> None: + """Best-effort GPU-memory release for a pipeline being swapped out.""" + if obj is None: + return + try: + del obj + except Exception: + pass + gc.collect() + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception: + pass + + +# ─── Module-level singleton ─────────────────────────────────────────── + + +_singleton: Optional[DiffusionBackend] = None +_singleton_lock = threading.Lock() + + +def get_diffusion_backend() -> DiffusionBackend: + """Return the process-wide diffusion backend (lazy-instantiated).""" + global _singleton + if _singleton is None: + with _singleton_lock: + if _singleton is None: + _singleton = DiffusionBackend() + return _singleton + + +async def async_generate( + backend: DiffusionBackend, + **kwargs: Any, +) -> "Any": + """Run ``generate_image`` in the default executor so route handlers + do not block the event loop for the 5-30 s a diffusion step takes.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, lambda: backend.generate_image(**kwargs)) diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index b5626951c4..f26220cc50 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -1421,3 +1421,70 @@ class AnthropicMessagesResponse(BaseModel): stop_reason: Optional[str] = None stop_sequence: Optional[str] = None usage: AnthropicUsage = Field(default_factory = AnthropicUsage) + + +# ── Diffusion image generation ──────────────────────────────────── + + +class DiffusionLoadRequest(BaseModel): + """Load a diffusion image-generation model. + + repo_id is the HF repo (either GGUF-only or full diffusers layout). + gguf_filename selects the quant when repo_id is a GGUF repo. + base_repo overrides the auto-picked diffusers base used for the + VAE / text encoders when loading a GGUF-only repo. + """ + + repo_id: str = Field(..., description = "HF repo id") + gguf_filename: Optional[str] = Field( + None, description = "GGUF filename inside repo_id (Q4_K_S, Q8_0, ...)" + ) + base_repo: Optional[str] = Field( + None, + description = "Diffusers base repo to source VAE + text encoders from", + ) + family: Optional[str] = Field( + None, + description = "Force pipeline family: flux.2-klein | flux.2 | flux.1 | qwen-image | stable-diffusion-3 | stable-diffusion-xl", + ) + hf_token: Optional[str] = Field( + None, description = "HuggingFace token for gated models" + ) + enable_model_cpu_offload: bool = Field( + True, + description = "Offload submodules to CPU between forwards. Trades a small speed hit for ~6 GB less VRAM on FLUX-class models.", + ) + + +class DiffusionGenerateRequest(BaseModel): + """Generate a single image from the currently-loaded diffusion model.""" + + prompt: str = Field(..., min_length = 1, max_length = 4000) + negative_prompt: Optional[str] = Field(None, max_length = 4000) + num_inference_steps: int = Field(24, ge = 1, le = 200) + guidance_scale: float = Field(3.5, ge = 0.0, le = 20.0) + width: int = Field(1024, ge = 64, le = 2048) + height: int = Field(1024, ge = 64, le = 2048) + seed: Optional[int] = Field( + None, description = "Deterministic seed for reproducible outputs" + ) + + @field_validator("width", "height") + @classmethod + def _multiple_of_eight(cls, v: int) -> int: + if v % 8: + raise ValueError("width and height must be multiples of 8") + return v + + +class DiffusionGenerateResponse(BaseModel): + image_b64: str = Field(..., description = "Base64-encoded PNG") + image_mime: str = "image/png" + width: int + height: int + num_inference_steps: int + guidance_scale: float + seed: Optional[int] = None + duration_ms: int + model: Optional[str] = None + family: Optional[str] = None diff --git a/studio/backend/requirements/no-torch-runtime.txt b/studio/backend/requirements/no-torch-runtime.txt index 85294114b1..fa3f33757e 100644 --- a/studio/backend/requirements/no-torch-runtime.txt +++ b/studio/backend/requirements/no-torch-runtime.txt @@ -46,6 +46,9 @@ peft>=0.18.0,!=0.11.0 huggingface_hub>=0.34.0 hf_transfer diffusers +# Required by diffusers.GGUFQuantizationConfig (used by the Images page +# to load FLUX.2 / FLUX.1 / Qwen-Image / SDXL GGUFs from the Hub). +gguf # Transitive deps required because this file is installed with --no-deps. # Without these, `from transformers import AutoConfig` fails at import time. diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index bf92055929..fc9cbf9f88 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -213,6 +213,9 @@ def _friendly_error(exc: Exception) -> str: ListOpenAIContainersResponse, OpenAIContainerRequest, OpenAIContainerSummary, + DiffusionLoadRequest, + DiffusionGenerateRequest, + DiffusionGenerateResponse, ) from core.inference.anthropic_compat import ( anthropic_messages_to_openai, @@ -1584,6 +1587,130 @@ async def generate_audio( ) +# ===================================================================== +# Diffusion image generation (/images/*) +# ===================================================================== +# +# Lifecycle mirrors the GGUF chat backend: explicit load -> generate -> +# unload. Diffusion pipelines compete for the same GPU as llama-server, +# so callers on < 24 GB GPUs should unload the chat model first. + + +def _get_diffusion_backend(): + """Lazy import so non-diffusion installs do not pay the diffusers + cost at process start. The backend itself is a process-wide + singleton; reusing it across requests keeps pipeline state alive.""" + from core.inference.diffusion import get_diffusion_backend + + return get_diffusion_backend() + + +@router.post("/images/load") +async def diffusion_load( + payload: DiffusionLoadRequest, + current_subject: str = Depends(get_current_subject), +): + """Load a diffusion image-generation model. + + Pass either a full diffusers repo or a GGUF-only repo plus the + desired ``gguf_filename``. Returns the new status payload (same + shape as ``/images/status``). + """ + backend = _get_diffusion_backend() + try: + status = await asyncio.get_event_loop().run_in_executor( + None, + lambda: backend.load_model( + repo_id = payload.repo_id, + gguf_filename = payload.gguf_filename, + base_repo = payload.base_repo, + family_override = payload.family, + hf_token = payload.hf_token, + enable_model_cpu_offload = payload.enable_model_cpu_offload, + ), + ) + return JSONResponse(content = status) + except RuntimeError as exc: + raise HTTPException(status_code = 400, detail = str(exc)) + except Exception as exc: + logger.exception("Diffusion load failed") + raise HTTPException(status_code = 500, detail = str(exc)) + + +@router.post("/images/unload") +async def diffusion_unload( + current_subject: str = Depends(get_current_subject), +): + """Unload the current diffusion model and free GPU memory.""" + backend = _get_diffusion_backend() + return backend.unload_model() + + +@router.get("/images/status") +async def diffusion_status( + current_subject: str = Depends(get_current_subject), +): + """Return diffusion backend status (loaded, family, device, etc.).""" + backend = _get_diffusion_backend() + return backend.status() + + +@router.post("/images/generate", response_model = DiffusionGenerateResponse) +async def diffusion_generate( + payload: DiffusionGenerateRequest, + current_subject: str = Depends(get_current_subject), +): + """Generate a single image from the loaded diffusion model. + + Returns a base64 PNG plus the generation parameters that produced + it so the frontend can render the result and the user can reproduce + it via the same seed. + """ + backend = _get_diffusion_backend() + if not backend.is_loaded: + raise HTTPException( + status_code = 400, + detail = "No diffusion model is loaded. POST /api/inference/images/load first.", + ) + + start = time.time() + try: + from core.inference.diffusion import async_generate, encode_png_base64 + + image = await async_generate( + backend, + prompt = payload.prompt, + negative_prompt = payload.negative_prompt, + num_inference_steps = payload.num_inference_steps, + guidance_scale = payload.guidance_scale, + width = payload.width, + height = payload.height, + seed = payload.seed, + ) + except ValueError as exc: + raise HTTPException(status_code = 400, detail = str(exc)) + except RuntimeError as exc: + raise HTTPException(status_code = 400, detail = str(exc)) + except Exception as exc: + logger.exception("Diffusion generation failed") + raise HTTPException(status_code = 500, detail = str(exc)) + + duration_ms = int((time.time() - start) * 1000) + status = backend.status() + return DiffusionGenerateResponse( + image_b64 = encode_png_base64(image), + image_mime = "image/png", + width = payload.width, + height = payload.height, + num_inference_steps = payload.num_inference_steps, + guidance_scale = payload.guidance_scale, + seed = payload.seed, + duration_ms = duration_ms, + model = status.get("repo_id"), + family = status.get("family"), + ) + + # ===================================================================== # OpenAI-Compatible Chat Completions (/chat/completions) # ===================================================================== diff --git a/studio/backend/tests/test_diffusion_backend.py b/studio/backend/tests/test_diffusion_backend.py new file mode 100644 index 0000000000..d70b4a2acb --- /dev/null +++ b/studio/backend/tests/test_diffusion_backend.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. + +"""Unit tests for the diffusion image-generation backend. + +These tests cover the surface area the routes layer relies on: + +* family detection from the public Unsloth GGUF naming conventions +* generation argument validation (empty prompt, bad steps, off-grid sizes) +* base64 PNG encoding round-trips +* status() shape stays compatible with the frontend status poller +* load/unload lifecycle with the heavy diffusers import monkey-patched + +Real GPU loads are exercised manually via the Studio probe (see +``studio/backend/tests/test_diffusion_smoke.py``); here we keep the +suite CPU- and import-free so the consolidated CI job and the +``unslothai/unsloth`` CI fork can both run it on Ubuntu, macOS, and +Windows runners with no diffusion dependencies installed. +""" + +from __future__ import annotations + +import base64 +import io +import sys +import types +from typing import Any + +import pytest + + +# ── module under test ──────────────────────────────────────────── + + +@pytest.fixture(autouse = True) +def _reset_singleton(monkeypatch): + """Reset the module-level singleton between tests so each test + starts from a known state without poking globals directly.""" + import core.inference.diffusion as d + + monkeypatch.setattr(d, "_singleton", None) + yield + + +# ── family detection ──────────────────────────────────────────── + + +def test_detect_family_flux2_klein(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-klein-4B-GGUF") + assert fam is not None + assert fam.name == "flux.2-klein" + assert fam.pipeline_class == "Flux2KleinPipeline" + assert fam.transformer_class == "Flux2Transformer2DModel" + + +def test_detect_family_flux2_dev_is_not_klein(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-dev-GGUF") + assert fam is not None + assert fam.name == "flux.2" + # Critical: FLUX.2 dev must NOT pick up the FLUX.2 klein pipeline + # because the transformer architectures and text encoder + # configurations are different. + assert fam.pipeline_class == "Flux2Pipeline" + + +def test_detect_family_flux1(): + from core.inference.diffusion import detect_family + + fam = detect_family("city96/FLUX.1-dev-gguf") + assert fam is not None + assert fam.name == "flux.1" + assert fam.pipeline_class == "FluxPipeline" + + +def test_detect_family_qwen_image(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/Qwen-Image-GGUF") + assert fam is not None + assert fam.name == "qwen-image" + + +def test_detect_family_override_wins_over_substring(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-dev-GGUF", override_family = "flux.1") + assert fam is not None + assert fam.name == "flux.1" + + +def test_detect_family_override_unknown_returns_none(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-klein-4B-GGUF", override_family = "doesnotexist") + assert fam is None + + +def test_detect_family_unknown_returns_none(): + from core.inference.diffusion import detect_family + + assert detect_family("random/repo") is None + assert detect_family("") is None + + +def test_supported_families_payload_shape(): + from core.inference.diffusion import supported_families + + payload = supported_families() + assert isinstance(payload, list) + assert len(payload) >= 4 + for entry in payload: + assert set(entry.keys()) == {"name", "pipeline_class", "base_repo"} + + +# ── singleton ─────────────────────────────────────────────────── + + +def test_get_diffusion_backend_singleton(): + from core.inference.diffusion import get_diffusion_backend + + a = get_diffusion_backend() + b = get_diffusion_backend() + assert a is b + + +# ── status() shape ────────────────────────────────────────────── + + +def test_status_shape_unloaded(): + from core.inference.diffusion import get_diffusion_backend + + s = get_diffusion_backend().status() + expected_keys = { + "is_loaded", + "is_loading", + "repo_id", + "family", + "pipeline_class", + "base_repo", + "gguf_path", + "device", + "dtype", + "loaded_at", + "last_error", + "supported_families", + } + assert expected_keys.issubset(s.keys()) + assert s["is_loaded"] is False + assert s["repo_id"] is None + + +# ── encode_png_base64 ─────────────────────────────────────────── + + +def test_encode_png_base64_round_trip(): + from PIL import Image + + from core.inference.diffusion import encode_png_base64 + + img = Image.new("RGB", (16, 16), color = (255, 0, 0)) + b64 = encode_png_base64(img) + raw = base64.b64decode(b64) + decoded = Image.open(io.BytesIO(raw)) + assert decoded.format == "PNG" + assert decoded.size == (16, 16) + + +# ── generation validation (no real pipeline) ──────────────────── + + +def _stub_pipeline(monkeypatch, *, returns = None, raises = None): + """Mount a fake torch pipeline on the singleton so generate_image's + argument validation runs without diffusers / torch being involved.""" + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + + class _StubPipe: + def __call__(self, **kwargs): + if raises is not None: + raise raises + class _Out: + pass + o = _Out() + o.images = [returns or Image.new("RGB", (kwargs["width"], kwargs["height"]), color = (0, 255, 0))] + return o + + backend._pipe = _StubPipe() + backend._device = "cpu" + backend._family = d._FAMILIES[0] + backend._repo_id = "stub/stub" + return backend + + +def test_generate_image_rejects_empty_prompt(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "prompt is empty"): + backend.generate_image(prompt = " ") + + +def test_generate_image_rejects_bad_steps(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "num_inference_steps"): + backend.generate_image(prompt = "cat", num_inference_steps = 0) + with pytest.raises(ValueError, match = "num_inference_steps"): + backend.generate_image(prompt = "cat", num_inference_steps = 999) + + +def test_generate_image_rejects_off_grid_size(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "multiples of 8"): + backend.generate_image(prompt = "cat", width = 513, height = 512) + + +def test_generate_image_rejects_oversized(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "width and height"): + backend.generate_image(prompt = "cat", width = 4096, height = 512) + + +def test_generate_image_calls_pipeline_with_kwargs(monkeypatch): + backend = _stub_pipeline(monkeypatch) + img = backend.generate_image( + prompt = "a red sphere", + negative_prompt = "blue", + num_inference_steps = 4, + guidance_scale = 1.0, + width = 256, + height = 256, + seed = 42, + ) + assert img.size == (256, 256) + + +def test_generate_image_unloaded_raises(monkeypatch): + import core.inference.diffusion as d + + backend = d.get_diffusion_backend() + backend._pipe = None + with pytest.raises(RuntimeError, match = "No diffusion model"): + backend.generate_image(prompt = "x") + + +def test_unload_clears_state(monkeypatch): + backend = _stub_pipeline(monkeypatch) + assert backend.is_loaded + backend.unload_model() + assert not backend.is_loaded + s = backend.status() + assert s["repo_id"] is None + assert s["family"] is None + + +# ── load_model (with monkey-patched diffusers) ────────────────── + + +def _install_fake_diffusers(monkeypatch, *, raise_on_pipeline = False): + """Build a tiny ``diffusers`` shim so we can exercise load_model + without dragging the real 1+ GB diffusers / torch import in.""" + from PIL import Image + + fake = types.ModuleType("diffusers") + fake.__version__ = "fake" + + class _FakeQuantConfig: + def __init__(self, compute_dtype = None): + self.compute_dtype = compute_dtype + + class _FakeTransformer: + @classmethod + def from_single_file(cls, path, quantization_config = None, torch_dtype = None): + inst = cls() + inst.path = path + inst.qc = quantization_config + inst.dtype = torch_dtype + return inst + + class _FakePipeline: + @classmethod + def from_pretrained(cls, base_repo, **kwargs): + if raise_on_pipeline: + raise RuntimeError("simulated load failure") + inst = cls() + inst.base_repo = base_repo + inst.kwargs = kwargs + return inst + + def __call__(self, **kwargs): + class _Out: + pass + o = _Out() + o.images = [Image.new("RGB", (kwargs["width"], kwargs["height"]), color = (0, 0, 255))] + return o + + def enable_model_cpu_offload(self): + self.cpu_offload = True + + def to(self, device): + self.device = device + return self + + fake.GGUFQuantizationConfig = _FakeQuantConfig + fake.Flux2KleinPipeline = _FakePipeline + fake.Flux2Transformer2DModel = _FakeTransformer + fake.Flux2Pipeline = _FakePipeline + fake.FluxPipeline = _FakePipeline + fake.FluxTransformer2DModel = _FakeTransformer + fake.QwenImagePipeline = _FakePipeline + fake.QwenImageTransformer2DModel = _FakeTransformer + fake.SD3Transformer2DModel = _FakeTransformer + fake.StableDiffusion3Pipeline = _FakePipeline + fake.StableDiffusionXLPipeline = _FakePipeline + + monkeypatch.setitem(sys.modules, "diffusers", fake) + + # Pretend HF Hub gave us a local file without actually fetching. + fake_hub = types.ModuleType("huggingface_hub") + fake_hub.hf_hub_download = lambda repo_id, filename, token = None: f"/fake/{repo_id}/{filename}" + monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub) + + # Force CPU dtype so the test does not need CUDA. + import core.inference.diffusion as d + + monkeypatch.setattr( + d.DiffusionBackend, + "_pick_device_and_dtype", + lambda self: ("cpu", "fake_dtype"), + ) + + return fake + + +def test_load_model_unknown_family(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + with pytest.raises(RuntimeError, match = "Could not infer"): + backend.load_model("private/random-repo") + + +def test_load_model_gguf_path_happy(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + status = backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "FLUX.2-klein-4B-Q4_K_S.gguf", + ) + assert status["is_loaded"] is True + assert status["family"] == "flux.2-klein" + assert status["pipeline_class"] == "Flux2KleinPipeline" + assert status["base_repo"] == "black-forest-labs/FLUX.2-klein" + assert status["gguf_path"] == ( + "/fake/unsloth/FLUX.2-klein-4B-GGUF/FLUX.2-klein-4B-Q4_K_S.gguf" + ) + + +def test_load_model_recovers_after_failure(monkeypatch): + _install_fake_diffusers(monkeypatch, raise_on_pipeline = True) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + with pytest.raises(RuntimeError, match = "Failed to load diffusion model"): + backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "x.gguf", + ) + # Failed load must leave the singleton unloaded but with last_error set. + s = backend.status() + assert s["is_loaded"] is False + assert s["last_error"] and "simulated load failure" in s["last_error"] + + +def test_load_model_swap_drops_previous(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "FLUX.2-klein-4B-Q4_K_S.gguf", + ) + first_pipe = backend._pipe + backend.load_model( + "unsloth/FLUX.2-dev-GGUF", + gguf_filename = "FLUX.2-dev-Q4_K_S.gguf", + ) + assert backend._pipe is not first_pipe + assert backend.status()["family"] == "flux.2" diff --git a/studio/backend/tests/test_diffusion_routes.py b/studio/backend/tests/test_diffusion_routes.py new file mode 100644 index 0000000000..9b9063f0b1 --- /dev/null +++ b/studio/backend/tests/test_diffusion_routes.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. + +"""Route-level tests for ``/api/inference/images/*``. + +Mounts the actual ``inference_router`` on a fresh FastAPI app with the +auth dependency replaced by a stub so we exercise the same FastAPI +handlers Studio ships in production. The diffusion backend is replaced +with an in-memory stub so we don't need diffusers / GPUs to run these. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from PIL import Image + + +_BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(_BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(_BACKEND_ROOT)) + + +class _FakeBackend: + def __init__(self) -> None: + self._loaded = False + self._repo: str | None = None + self.calls: list[dict] = [] + + @property + def is_loaded(self) -> bool: + return self._loaded + + def status(self) -> dict: + return { + "is_loaded": self._loaded, + "is_loading": False, + "repo_id": self._repo, + "family": "flux.2-klein" if self._loaded else None, + "pipeline_class": "Flux2KleinPipeline" if self._loaded else None, + "base_repo": "black-forest-labs/FLUX.2-klein" if self._loaded else None, + "gguf_path": None, + "device": "cpu", + "dtype": "torch.bfloat16", + "loaded_at": 0, + "last_error": None, + "supported_families": [], + } + + def load_model(self, repo_id, **kw): + self.calls.append({"op": "load", "repo_id": repo_id, **kw}) + self._loaded = True + self._repo = repo_id + return self.status() + + def unload_model(self) -> dict: + self._loaded = False + self._repo = None + return {"is_loaded": False} + + def generate_image(self, **kw): + self.calls.append({"op": "generate", **kw}) + return Image.new("RGB", (kw["width"], kw["height"]), color = (123, 45, 67)) + + +@pytest.fixture +def app_with_stub(monkeypatch): + """Build a FastAPI app that mounts the real inference router with + auth disabled and the diffusion backend swapped for a stub.""" + from routes import inference as inf + import core.inference.diffusion as d + + stub = _FakeBackend() + # Override the singleton accessor the route uses. + monkeypatch.setattr(d, "get_diffusion_backend", lambda: stub) + monkeypatch.setattr(inf, "_get_diffusion_backend", lambda: stub) + + app = FastAPI() + app.include_router(inf.router, prefix = "/api/inference") + # Bypass auth by overriding the dependency. + from auth.authentication import get_current_subject + + app.dependency_overrides[get_current_subject] = lambda: "test-user" + + return app, stub + + +def test_status_when_unloaded(app_with_stub): + app, _ = app_with_stub + c = TestClient(app) + r = c.get("/api/inference/images/status") + assert r.status_code == 200 + body = r.json() + assert body["is_loaded"] is False + assert body["repo_id"] is None + + +def test_generate_without_load_returns_400(app_with_stub): + app, _ = app_with_stub + c = TestClient(app) + r = c.post( + "/api/inference/images/generate", + json = {"prompt": "a red sphere"}, + ) + assert r.status_code == 400 + assert "No diffusion model" in r.json()["detail"] + + +def test_load_then_generate_round_trip(app_with_stub): + app, stub = app_with_stub + c = TestClient(app) + + r = c.post( + "/api/inference/images/load", + json = { + "repo_id": "unsloth/FLUX.2-klein-4B-GGUF", + "gguf_filename": "FLUX.2-klein-4B-Q4_K_S.gguf", + }, + ) + assert r.status_code == 200, r.text + assert r.json()["is_loaded"] is True + + r = c.post( + "/api/inference/images/generate", + json = { + "prompt": "a tiny synth-pop album cover", + "width": 256, + "height": 256, + "num_inference_steps": 4, + "seed": 7, + }, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["image_b64"] + assert body["image_mime"] == "image/png" + assert body["width"] == 256 + assert body["height"] == 256 + assert body["seed"] == 7 + assert body["duration_ms"] >= 0 + + # Round-trip the base64 -> PIL to confirm it is a real PNG of the + # right size and not, say, an empty string. + import base64 + import io + + raw = base64.b64decode(body["image_b64"]) + decoded = Image.open(io.BytesIO(raw)) + assert decoded.format == "PNG" + assert decoded.size == (256, 256) + + # Backend stub should have recorded both calls. + ops = [c["op"] for c in stub.calls] + assert ops == ["load", "generate"] + + +def test_generate_rejects_off_grid_size(app_with_stub): + app, stub = app_with_stub + c = TestClient(app) + c.post( + "/api/inference/images/load", + json = { + "repo_id": "unsloth/FLUX.2-klein-4B-GGUF", + "gguf_filename": "x.gguf", + }, + ) + r = c.post( + "/api/inference/images/generate", + json = {"prompt": "x", "width": 513, "height": 512}, + ) + # Pydantic v2 wraps validator errors in 422 by default. + assert r.status_code in (400, 422), r.text + + +def test_unload_clears_state(app_with_stub): + app, _ = app_with_stub + c = TestClient(app) + c.post( + "/api/inference/images/load", + json = {"repo_id": "unsloth/FLUX.2-klein-4B-GGUF", "gguf_filename": "x.gguf"}, + ) + r = c.post("/api/inference/images/unload") + assert r.status_code == 200 + assert r.json()["is_loaded"] is False + r = c.get("/api/inference/images/status") + assert r.json()["is_loaded"] is False diff --git a/studio/frontend/src/app/router.tsx b/studio/frontend/src/app/router.tsx index c7bc0440bd..b50f3fe618 100644 --- a/studio/frontend/src/app/router.tsx +++ b/studio/frontend/src/app/router.tsx @@ -9,6 +9,7 @@ import { Route as dataRecipeRoute } from "./routes/data-recipes.$recipeId"; import { Route as chatRoute } from "./routes/chat"; import { Route as exportRoute } from "./routes/export"; import { Route as gridTestRoute } from "./routes/grid-test"; +import { Route as imagesRoute } from "./routes/images"; import { Route as indexRoute } from "./routes/index"; import { Route as loginRoute } from "./routes/login"; import { Route as onboardingRoute } from "./routes/onboarding"; @@ -26,6 +27,7 @@ const routeTree = rootRoute.addChildren([ studioRoute, chatRoute, exportRoute, + imagesRoute, dataRecipesRoute, dataRecipeRoute, ]); diff --git a/studio/frontend/src/app/routes/images.tsx b/studio/frontend/src/app/routes/images.tsx new file mode 100644 index 0000000000..1761612140 --- /dev/null +++ b/studio/frontend/src/app/routes/images.tsx @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +import { createRoute } from "@tanstack/react-router"; +import { lazy } from "react"; +import { requireAuth } from "../auth-guards"; +import { Route as rootRoute } from "./__root"; + +const ImagesPage = lazy(() => + import("@/features/images").then((m) => ({ + default: m.ImagesPage, + })), +); + +export const Route = createRoute({ + getParentRoute: () => rootRoute, + path: "/images", + staticData: { title: "Images" }, + beforeLoad: () => requireAuth(), + component: ImagesPage, +}); diff --git a/studio/frontend/src/components/app-sidebar.tsx b/studio/frontend/src/components/app-sidebar.tsx index aac5f8f8a8..9a0830db2d 100644 --- a/studio/frontend/src/components/app-sidebar.tsx +++ b/studio/frontend/src/components/app-sidebar.tsx @@ -50,6 +50,7 @@ import { Globe02Icon, HelpCircleIcon, Logout01Icon, + PaintBrush02Icon, Search01Icon, PowerIcon, PencilEdit02Icon, @@ -497,6 +498,18 @@ export function AppSidebar() { }} /> + { + if (chatOnly) return; + navigate({ to: "/images" }); + closeMobileIfOpen(); + }} + /> + (res: Response): Promise { + if (!res.ok) throw new Error(await readFastApiError(res)); + return (await res.json()) as T; +} + +export async function fetchDiffusionStatus(): Promise { + return parseJson( + await authFetch("/api/inference/images/status"), + ); +} + +export async function loadDiffusionModel( + payload: DiffusionLoadRequest, +): Promise { + return parseJson( + await authFetch("/api/inference/images/load", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }), + ); +} + +export async function unloadDiffusionModel(): Promise<{ is_loaded: boolean }> { + return parseJson<{ is_loaded: boolean }>( + await authFetch("/api/inference/images/unload", { method: "POST" }), + ); +} + +export async function generateDiffusionImage( + payload: DiffusionGenerateRequest, +): Promise { + return parseJson( + await authFetch("/api/inference/images/generate", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }), + ); +} diff --git a/studio/frontend/src/features/images/images-page.tsx b/studio/frontend/src/features/images/images-page.tsx new file mode 100644 index 0000000000..3a408ba316 --- /dev/null +++ b/studio/frontend/src/features/images/images-page.tsx @@ -0,0 +1,425 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { SectionCard } from "@/components/section-card"; +import { Slider } from "@/components/ui/slider"; +import { Spinner } from "@/components/ui/spinner"; +import { Textarea } from "@/components/ui/textarea"; +import { toast } from "@/lib/toast"; +import { + fetchDiffusionStatus, + generateDiffusionImage, + loadDiffusionModel, + unloadDiffusionModel, + type DiffusionGenerateResponse, + type DiffusionStatus, +} from "./api"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; + +// Curated short list of working unsloth/* diffusion GGUFs. Picked to +// span size + license so any GPU class has at least one viable option: +// FLUX.2 klein 4B -> ~10-12 GB VRAM with Q4_K_S, Apache 2.0 +// FLUX.2 klein 9B -> ~16-18 GB VRAM, FLUX [klein] non-commercial +// FLUX.2 dev -> ~24+ GB VRAM, FLUX [dev] non-commercial +// The CLI on the backend can load anything supported by detect_family(); +// this list just keeps the picker compact for the v1 UI. +const CURATED_MODELS: Array<{ + label: string; + repo_id: string; + default_gguf: string; + family: string; + notes: string; +}> = [ + { + label: "FLUX.2 klein 4B (Q4_K_S, Apache 2.0)", + repo_id: "unsloth/FLUX.2-klein-4B-GGUF", + default_gguf: "FLUX.2-klein-4B-Q4_K_S.gguf", + family: "flux.2-klein", + notes: "13 GB VRAM, fastest. Apache 2.0.", + }, + { + label: "FLUX.2 klein 9B (Q4_K_S)", + repo_id: "unsloth/FLUX.2-klein-9B-GGUF", + default_gguf: "FLUX.2-klein-9B-Q4_K_S.gguf", + family: "flux.2-klein", + notes: "17 GB VRAM, higher quality.", + }, + { + label: "FLUX.2 dev (Q4_K_S)", + repo_id: "unsloth/FLUX.2-dev-GGUF", + default_gguf: "FLUX.2-dev-Q4_K_S.gguf", + family: "flux.2", + notes: "24+ GB VRAM, best for prompt following.", + }, + { + label: "FLUX.1 dev (Q4_K_S, city96)", + repo_id: "city96/FLUX.1-dev-gguf", + default_gguf: "flux1-dev-Q4_K_S.gguf", + family: "flux.1", + notes: "12 GB VRAM, older but well tested.", + }, +]; + +const DEFAULT_PRESET = CURATED_MODELS[0]; + +const RESOLUTION_PRESETS: Array<{ label: string; w: number; h: number }> = [ + { label: "Square 1024", w: 1024, h: 1024 }, + { label: "Square 768", w: 768, h: 768 }, + { label: "Square 512", w: 512, h: 512 }, + { label: "Portrait 832x1216", w: 832, h: 1216 }, + { label: "Landscape 1216x832", w: 1216, h: 832 }, +]; + +export function ImagesPage() { + const [status, setStatus] = useState(null); + const [refreshingStatus, setRefreshingStatus] = useState(false); + const [busy, setBusy] = useState<"idle" | "loading" | "unloading" | "generating">("idle"); + + const [presetIndex, setPresetIndex] = useState(0); + const [customRepoId, setCustomRepoId] = useState(""); + const [customGguf, setCustomGguf] = useState(""); + const [useCustom, setUseCustom] = useState(false); + const [hfToken, setHfToken] = useState(""); + + const [prompt, setPrompt] = useState("a tiny ginger sloth coding in a sunlit treehouse, photorealistic"); + const [negativePrompt, setNegativePrompt] = useState(""); + const [steps, setSteps] = useState(24); + const [guidance, setGuidance] = useState(3.5); + const [resolutionIdx, setResolutionIdx] = useState(0); + const [seed, setSeed] = useState(""); + + const [results, setResults] = useState([]); + const lastErrorRef = useRef(null); + + const preset = CURATED_MODELS[presetIndex] ?? DEFAULT_PRESET; + const resolution = RESOLUTION_PRESETS[resolutionIdx]; + + const refreshStatus = useCallback(async () => { + setRefreshingStatus(true); + try { + const next = await fetchDiffusionStatus(); + setStatus(next); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + if (lastErrorRef.current !== msg) { + lastErrorRef.current = msg; + toast.error("Could not fetch image-model status", { description: msg }); + } + } finally { + setRefreshingStatus(false); + } + }, []); + + useEffect(() => { + void refreshStatus(); + }, [refreshStatus]); + + const handleLoad = useCallback(async () => { + setBusy("loading"); + try { + const repo = useCustom ? customRepoId.trim() : preset.repo_id; + const gguf = useCustom ? customGguf.trim() || undefined : preset.default_gguf; + const family = useCustom ? undefined : preset.family; + if (!repo) { + toast.error("Pick a model first"); + return; + } + const next = await loadDiffusionModel({ + repo_id: repo, + gguf_filename: gguf, + family, + hf_token: hfToken.trim() || undefined, + }); + setStatus(next); + toast.success("Loaded image model", { description: next.repo_id ?? undefined }); + } catch (err) { + toast.error("Failed to load image model", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setBusy("idle"); + } + }, [useCustom, customRepoId, customGguf, preset, hfToken]); + + const handleUnload = useCallback(async () => { + setBusy("unloading"); + try { + await unloadDiffusionModel(); + await refreshStatus(); + } catch (err) { + toast.error("Failed to unload image model", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setBusy("idle"); + } + }, [refreshStatus]); + + const handleGenerate = useCallback(async () => { + if (!prompt.trim()) { + toast.error("Prompt is empty"); + return; + } + setBusy("generating"); + try { + const parsedSeed = seed.trim() ? Number(seed.trim()) : undefined; + if (parsedSeed !== undefined && !Number.isFinite(parsedSeed)) { + toast.error("Seed must be a number"); + return; + } + const out = await generateDiffusionImage({ + prompt, + negative_prompt: negativePrompt.trim() || undefined, + num_inference_steps: steps, + guidance_scale: guidance, + width: resolution.w, + height: resolution.h, + seed: parsedSeed, + }); + setResults((prev) => [out, ...prev].slice(0, 12)); + } catch (err) { + toast.error("Image generation failed", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setBusy("idle"); + } + }, [prompt, negativePrompt, steps, guidance, resolution, seed]); + + const statusLabel = useMemo(() => { + if (!status) return refreshingStatus ? "Checking..." : "Not loaded"; + if (status.is_loading) return "Loading..."; + if (status.is_loaded) { + const dev = status.device ? ` on ${status.device}` : ""; + return `Loaded: ${status.repo_id ?? "(unknown)"} (${status.family ?? "unknown"})${dev}`; + } + return "Not loaded"; + }, [status, refreshingStatus]); + + return ( +
+ +
+
+ + + {!useCustom && ( +

{preset.notes}

+ )} +
+ + {useCustom && ( +
+ + setCustomRepoId(e.target.value)} + placeholder="unsloth/FLUX.2-klein-4B-GGUF" + /> + + setCustomGguf(e.target.value)} + placeholder="FLUX.2-klein-4B-Q4_K_S.gguf" + /> +
+ )} + +
+ + setHfToken(e.target.value)} + placeholder="hf_..." + autoComplete="off" + /> +
+ +
+ + + + + {statusLabel} + +
+
+
+ + +
+
+ +