diff --git a/pyproject.toml b/pyproject.toml index aef88d90f5..6c37d50f80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,9 +81,30 @@ huggingfacenotorch = [ "datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0", "accelerate>=0.34.1", "peft>=0.18.0,!=0.11.0", + # Round 33 P1: reverted the round-26 hub>=1.3.0 floor. studio.txt + # forces hub==0.36.2 to match the transformers 4.57.6 pin in + # extras-no-deps.txt; the 1.3.0 floor here was internally + # inconsistent and reviewers reproduced the resolver conflict. + # Align with the colab-new extra's 0.34.0 floor (line 610). The + # transformers-5.x is_offline_mode concern that motivated the + # original bump never triggers because transformers is pinned at + # 4.57.6 on the supported install path. "huggingface_hub>=0.34.0", "hf_transfer", - "diffusers", + # Studio Images page depends on Flux2KleinPipeline / + # Flux2Pipeline, both shipped in diffusers>=0.37.0. Floor was + # missing here so a `pip install unsloth[huggingfacenotorch]` + # could resolve to 0.36.0 and fail at runtime when the default + # curated FLUX.2 klein model loads. + "diffusers>=0.37.0", + # diffusers.GGUFQuantizationConfig + from_single_file rely on + # the standalone gguf package at runtime. Floor at 0.10.0 to + # match the diffusers requirement; older gguf releases raise + # at load time. Studio Images default curated picker is + # GGUF-only so this must install with the public + # huggingfacenotorch extra; missing / under-pinned it makes + # /api/inference/images/load 500. 
+ "gguf>=0.10.0", "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.5.0", "trl>=0.18.2,!=0.19.0,<=0.24.0", "sentence-transformers", diff --git a/studio/backend/core/inference/diffusion.py b/studio/backend/core/inference/diffusion.py new file mode 100644 index 0000000000..c66a3690e5 --- /dev/null +++ b/studio/backend/core/inference/diffusion.py @@ -0,0 +1,2009 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +"""Diffusion image generation backend. + +Loads Hugging Face diffusion checkpoints in either the standard +``diffusers`` layout or the single-file GGUF layout published under +``unsloth/*-GGUF`` (Flux 2, Flux 2 Klein, Qwen-Image, SD3, SDXL, ...). +GGUF files are dynamically dequantised on-device via +``diffusers.GGUFQuantizationConfig``, then the rest of the pipeline +(VAE, text encoders, scheduler) is pulled from the matching ``diffusers`` +repo so end users only ever need one local file plus the metadata repo. + +The module is intentionally torch-only: it never spawns a subprocess and +shares the active CUDA / MPS device with the rest of Studio. The cost of +not having a separate process is that loading a diffusion model and a +GGUF chat model at the same time can OOM on consumer GPUs; the routes +layer must therefore swap between the two as needed (the orchestrator +unloads llama-server before any diffusion load on hosts with < 24 GB). + +The class deliberately exposes a small, llama-cpp-style surface: + + load_model(repo_id, ...) + generate_image(prompt, ...) -> PIL.Image + unload_model() + status() -> dict + +so the route layer at ``studio/backend/routes/inference.py`` can mirror +the existing llama-server lifecycle (probe + load + generate + unload) +without learning a second API. 
# ─── Pipeline registry ────────────────────────────────────────────────
#
# Keep this list narrow on purpose: only ship the small text-to-image
# families with first-class GGUF coverage on the Hub. Anything else is
# either video (LTX*, Wan) or research-grade (Sana, SD3.5) and can be
# added once it has a working GGUF release plus a smoke test.
#
# Each entry pairs the terms ``detect_family`` matches against the
# loaded repo id (``name`` plus ``aliases``, case-insensitive) with the
# pipeline class name, transformer class name, and default base repo
# for missing pieces. ``base_repo`` is what we pass to
# ``Pipeline.from_pretrained`` to pick up the VAE + text encoders when
# the user gave us a GGUF-only repo. The base_repo is documented to the
# user via ``status()`` (through ``supported_families()``) so they
# understand why a second download fires.


@dataclass(frozen = True)
class DiffusionFamily:
    """Immutable description of one supported diffusion model family."""

    # Canonical family name; also the value accepted by
    # ``detect_family(..., override_family=...)``.
    name: str
    # Name of the pipeline class (stored as a string, not a class object).
    pipeline_class: str
    # Name of the transformer class; empty for families without one
    # (see the SDXL entry in ``_FULL_REPO_FAMILIES``).
    transformer_class: str
    # Default companion diffusers repo for this family.
    base_repo: str
    # Optional: list of HF "trigger" substrings besides ``name`` that map
    # to this family (e.g. "flux1-dev" plus "flux.1-dev"). Lowercased.
    aliases: tuple[str, ...] = field(default_factory = tuple)


_FAMILIES: tuple[DiffusionFamily, ...] = (
    # Order matters: ``detect_family`` returns the first entry whose
    # name or alias matches, and the ``flux.2`` term also token-matches
    # klein repo names, so ``flux.2-klein`` has to stay ahead of
    # ``flux.2``. Choosing between the 4B / 9B and Base / distilled
    # companions is done by ``_smart_base_repo`` from the repo name,
    # not by an alias in this table.
    DiffusionFamily(
        name = "flux.2-klein",
        pipeline_class = "Flux2KleinPipeline",
        transformer_class = "Flux2Transformer2DModel",
        # Family-level default (4B Base). The frontend curated picker
        # always passes base_repo explicitly, so a default only matters
        # for "custom HF repo" mode.
        # NOTE(review): ``_smart_base_repo`` never returns this value for
        # the klein family (its fallback is the distilled
        # ``FLUX.2-klein-4B``); within the code shown here this field is
        # only surfaced through ``supported_families()``. Confirm which
        # default is intended.
        base_repo = "black-forest-labs/FLUX.2-klein-base-4B",
        aliases = ("flux2-klein", "flux-2-klein", "flux.2.klein"),
    ),
    DiffusionFamily(
        name = "flux.2",
        pipeline_class = "Flux2Pipeline",
        transformer_class = "Flux2Transformer2DModel",
        base_repo = "black-forest-labs/FLUX.2-dev",
        aliases = ("flux2-dev", "flux-2-dev", "flux.2.dev"),
    ),
    DiffusionFamily(
        name = "flux.1",
        pipeline_class = "FluxPipeline",
        transformer_class = "FluxTransformer2DModel",
        base_repo = "black-forest-labs/FLUX.1-dev",
        aliases = ("flux1-dev", "flux-1-dev", "flux.1.dev", "flux-dev"),
    ),
    DiffusionFamily(
        name = "qwen-image",
        pipeline_class = "QwenImagePipeline",
        transformer_class = "QwenImageTransformer2DModel",
        base_repo = "Qwen/Qwen-Image",
        aliases = ("qwenimage", "qwen_image"),
    ),
    DiffusionFamily(
        name = "stable-diffusion-3",
        pipeline_class = "StableDiffusion3Pipeline",
        transformer_class = "SD3Transformer2DModel",
        base_repo = "stabilityai/stable-diffusion-3-medium-diffusers",
        # Intentionally NOT including "sd3.5" / "stable-diffusion-3.5"
        # here: the SD3.5 family uses a different transformer config and
        # base repo than SD3 Medium, and silently pairing SD3.5 GGUFs
        # with the Medium base produces a misleading load. Add a
        # dedicated SD3.5 family with its own base_repo when we ship
        # smoke coverage for it.
        aliases = ("sd3-medium", "stable-diffusion-3-medium"),
    ),
    # SDXL: full diffusers path only (no GGUF). SDXL uses a UNet (not a
    # transformer) and wiring UNet2DConditionModel.from_single_file +
    # GGUF is a separate code path the rest of this module does not
    # exercise. The family is intentionally NOT in _FAMILIES so the
    # frontend status panel does not advertise GGUF support we do not
    # implement; callers wanting SDXL full repos can still do so by
    # passing the diffusers repo with no gguf_filename and
    # family_override = "stable-diffusion-xl" via the route, which uses
    # the lookup in _FULL_REPO_FAMILIES.
)


# Families available via family_override on the routes layer when the
# user is loading a full diffusers checkpoint (no GGUF). Kept separate
# from _FAMILIES so the GGUF-only status panel does not over-advertise:
# ``supported_families()`` iterates ``_FAMILIES`` only, while
# ``detect_family`` scans both tuples.
_FULL_REPO_FAMILIES: tuple[DiffusionFamily, ...] = (
    DiffusionFamily(
        name = "stable-diffusion-xl",
        pipeline_class = "StableDiffusionXLPipeline",
        transformer_class = "",
        base_repo = "stabilityai/stable-diffusion-xl-base-1.0",
        aliases = ("sdxl",),
    ),
)
def _smart_base_repo(fam: DiffusionFamily, repo_id: str) -> str:
    """Choose the companion diffusers repo for a GGUF repo.

    Used when the caller did not pass an explicit ``base_repo``. Every
    family except ``flux.2-klein`` simply returns ``fam.base_repo``.
    For klein the last path segment of ``repo_id`` is inspected
    (case-insensitively):

    * contains ``9b`` and ``base`` -> ``FLUX.2-klein-base-9B``
    * contains ``9b`` only         -> ``FLUX.2-klein-9B``
    * contains ``base`` only       -> ``FLUX.2-klein-base-4B``
    * neither                      -> ``FLUX.2-klein-4B`` (distilled)

    Only the leaf segment is considered, and both ``/`` and ``\\`` are
    treated as separators, so a namespace or parent directory such as
    ``baseorg/...`` or ``C:\\Users\\me\\base\\FLUX.2-klein-4B`` cannot
    select the Base variant by accident.
    """
    if fam.name != "flux.2-klein":
        return fam.base_repo

    trimmed = (repo_id or "").rstrip("/\\")
    leaf = re.split(r"[\\/]+", trimmed)[-1].lower() if trimmed else ""

    # Keyed by (mentions 9b, mentions base).
    klein_bases = {
        (True, True): "black-forest-labs/FLUX.2-klein-base-9B",
        (True, False): "black-forest-labs/FLUX.2-klein-9B",
        (False, True): "black-forest-labs/FLUX.2-klein-base-4B",
        (False, False): "black-forest-labs/FLUX.2-klein-4B",
    }
    return klein_bases[("9b" in leaf, "base" in leaf)]


def _expand_existing_local_path(value: str) -> str:
    """Expand a leading ``~`` only when the expanded path exists.

    Values that are empty, not strings, or do not start with ``~`` are
    returned untouched, so Hub ids pass straight through. A ``~`` path
    whose expansion does not exist is also returned as-is, leaving the
    downstream loader to report its own "not found" error.
    """
    is_tilde_path = bool(value) and isinstance(value, str) and value.startswith("~")
    if is_tilde_path:
        expanded = Path(value).expanduser()
        if expanded.exists():
            return str(expanded)
    return value
def _preflight_diffusers_subfolder_config(
    repo: str,
    subfolder: str,
    hf_token: Optional[str],
) -> None:
    """Verify that ``{subfolder}/config.json`` is reachable for ``repo``.

    Complements ``_preflight_full_diffusers_repo``, which only proves
    that ``model_index.json`` exists. Returns ``None`` when the check
    passes or cannot be performed (empty ``repo`` / ``subfolder``, or
    ``huggingface_hub`` cannot be imported).

    Raises:
        RuntimeError: the local directory exists but lacks
            ``{subfolder}/config.json``, or the Hub download of that
            file failed for any reason.
    """
    if not (repo and subfolder):
        return

    try:
        local_dir = Path(repo).expanduser()
    except (OSError, ValueError):
        local_dir = None

    if local_dir is not None and local_dir.exists():
        if (local_dir / subfolder / "config.json").is_file():
            return
        raise RuntimeError(
            f"Diffusion repo '{_display_repo_id(repo)}' is missing "
            f"{subfolder}/config.json."
        )

    # A path-looking value that is absent on disk is not rejected here;
    # it falls through to the Hub probe below.
    try:
        from huggingface_hub import hf_hub_download as _hf_hub_download
    except Exception:
        return

    try:
        _hf_hub_download(
            repo_id = repo,
            filename = "config.json",
            subfolder = subfolder,
            token = hf_token,
        )
    except Exception as exc:
        raise RuntimeError(
            f"Could not access diffusion repo '{_display_repo_id(repo)}' "
            f"{subfolder}/config.json before unloading the current model."
        ) from exc
def _preflight_full_diffusers_repo(repo: str, hf_token: Optional[str]) -> None:
    """Prove a full diffusers repo is accessible before any unloads.

    Probes for ``model_index.json``. An existing local path is checked
    on disk without touching the network; anything else is probed with
    a single ``hf_hub_download`` call. Returns ``None`` on success, for
    an empty ``repo``, or when ``huggingface_hub`` cannot be imported
    (the downstream loader then reports the canonical error).

    Error messages label the repo through ``_display_repo_id`` so an
    absolute filesystem path is reduced to its leaf name.

    Raises:
        RuntimeError: the local path is not a directory, lacks
            ``model_index.json``, looks local (absolute or ``~``-prefixed)
            but does not exist, or the Hub probe failed.
    """
    if not repo:
        return

    try:
        local_dir = Path(repo).expanduser()
    except (OSError, ValueError):
        local_dir = None

    if local_dir is not None and local_dir.exists():
        if not local_dir.is_dir():
            raise RuntimeError(
                f"Diffusion repo '{_display_repo_id(repo)}' is not a directory."
            )
        manifest = local_dir / "model_index.json"
        if manifest.is_file():
            return
        raise RuntimeError(
            f"Diffusion repo '{_display_repo_id(repo)}' is missing "
            "model_index.json."
        )

    looks_local = (local_dir is not None and local_dir.is_absolute()) or repo.startswith("~")
    if looks_local:
        raise RuntimeError(
            f"Local diffusion repo '{_display_repo_id(repo)}' does not exist."
        )

    try:
        from huggingface_hub import hf_hub_download as _hf_hub_download
    except Exception:
        # huggingface_hub is unavailable -- let the downstream loader
        # produce the canonical error.
        return

    try:
        _hf_hub_download(
            repo_id = repo,
            filename = "model_index.json",
            token = hf_token,
        )
    except Exception as exc:
        raise RuntimeError(
            f"Could not access diffusion repo '{_display_repo_id(repo)}' "
            "before unloading the current model."
        ) from exc
# Matches ``hf_`` followed by at least 20 alphanumeric characters.
_HF_TOKEN_RE = re.compile(r"hf_[A-Za-z0-9]{20,}")


def _redact_hf_tokens(value: Any) -> Any:
    """Strip embedded ``hf_...`` tokens from a string before logging.

    Every match of ``_HF_TOKEN_RE`` is removed (replaced by the empty
    string). Non-string values are returned unchanged, so the helper
    can be applied to arbitrary log arguments.
    """
    return _HF_TOKEN_RE.sub("", value) if isinstance(value, str) else value


def _display_repo_id(value: Any) -> Any:
    """Return a public-facing label for a repo_id / base_repo.

    Hub-style identifiers (``owner/repo``) pass through unchanged apart
    from token redaction. A value that is an absolute path, or that
    exists on disk after ``~`` expansion, collapses to its leaf name so
    the filesystem layout is not exposed; if the leaf name is empty the
    original value is used instead. Non-strings and empty strings are
    returned as-is.
    """
    if not (isinstance(value, str) and value):
        return value

    label = value
    try:
        as_path = Path(value).expanduser()
        if as_path.is_absolute() or as_path.exists():
            label = as_path.name or value
    except (OSError, ValueError):
        # Unusable as a path: fall back to the raw value.
        pass
    # Defense-in-depth: redact any hf_... pattern before the label
    # reaches the UI or a log line.
    return _redact_hf_tokens(label)
def _resolve_local_gguf_child(repo_root: Path, gguf_filename: str) -> Path:
    """Resolve a GGUF filename inside a local repo directory safely.

    Args:
        repo_root: Local repo directory chosen by the caller; ``~`` is
            expanded and symlinks are resolved.
        gguf_filename: User-supplied POSIX-style relative path of the
            GGUF file inside ``repo_root`` (e.g. ``Q4_K_M/model.gguf``).

    Returns:
        The fully resolved path of the GGUF file.

    Raises:
        RuntimeError: if ``gguf_filename`` is absolute, starts with a
            separator, or contains a backslash; if any of its parts is
            ``""`` / ``.`` / ``..``; if ``repo_root`` does not exist;
            if the candidate does not exist, escapes ``repo_root``
            after symlink / ``..`` collapse, or is not a regular file.

    This bridges a user-supplied string into a ``Path`` the loader
    opens, so every failure mode is reported as ``RuntimeError`` and
    the result is confined to the chosen repo.
    """
    # ``Path("/etc/passwd").is_absolute()`` is False on Windows (POSIX
    # absolute paths read as drive-relative), so check both pathlib
    # flavours plus a leading separator so the rejection is portable.
    if (
        Path(gguf_filename).is_absolute()
        or PurePosixPath(gguf_filename).is_absolute()
        or gguf_filename.startswith(("/", "\\"))
        or "\\" in gguf_filename
    ):
        raise RuntimeError("gguf_filename must be a relative file path inside repo_id.")
    rel = PurePosixPath(gguf_filename)
    if any(part in ("", ".", "..") for part in rel.parts):
        raise RuntimeError(
            "gguf_filename must not contain empty, '.', or '..' segments."
        )
    # A missing or unreadable repo root used to escape as a raw
    # FileNotFoundError / OSError; convert it so callers that only
    # handle RuntimeError still get a user-facing failure.
    try:
        root = repo_root.expanduser().resolve(strict = True)
    except OSError as exc:
        raise RuntimeError(
            f"Local repo path '{repo_root}' does not exist."
        ) from exc
    try:
        candidate = (root / Path(*rel.parts)).resolve(strict = True)
    except OSError as exc:
        # strict=True raises FileNotFoundError (an OSError subclass) on
        # a missing leaf or parent component, and OSError on a
        # malformed path. Either way the file is not inside the repo.
        raise RuntimeError(
            f"Local repo path '{repo_root}' does not contain '{gguf_filename}'."
        ) from exc
    try:
        # Symlinks are already resolved, so this catches links that
        # point outside the repo directory.
        candidate.relative_to(root)
    except ValueError as exc:
        raise RuntimeError(
            "gguf_filename must stay inside the local repo_id directory."
        ) from exc
    if not candidate.is_file():
        raise RuntimeError(
            f"Local repo path '{repo_root}' does not contain '{gguf_filename}'."
        )
    return candidate
_FAMILY_EXCLUDE: dict[str, tuple[str, ...]] = {
    # Spellings that mark an SD3.5 repo; any of them disqualifies the
    # SD3 Medium family.
    "stable-diffusion-3": (
        "3.5",
        "3-5",
        "3_5",
        "stable-diffusion-3.5",
        "stable_diffusion_3_5",
    ),
    # All underscore / hyphen spellings that appear in Hub repo ids for
    # the *-Edit family must exclude Qwen-Image, otherwise
    # ``unsloth/qwen_image_edit-GGUF`` matches the Qwen-Image base.
    "qwen-image": (
        "qwen-image-edit",
        "qwenimage-edit",
        "qwen_image_edit",
        "qwenimageedit",
    ),
}


def detect_family(
    repo_id: str, *, override_family: Optional[str] = None
) -> Optional[DiffusionFamily]:
    """Return the diffusion family matching ``repo_id``.

    Matching is case-insensitive and token-boundary based, with a small
    deny list (``_FAMILY_EXCLUDE``) for known false positives such as
    SD3.5 (would otherwise match SD3 Medium) and Qwen-Image-Edit
    (would otherwise match Qwen-Image). Families are tried in the order
    ``_FAMILIES`` then ``_FULL_REPO_FAMILIES`` and the first match wins.

    ``override_family`` bypasses repo-id matching entirely: it is
    stripped, lowercased, and compared for equality against each
    family's ``name`` (aliases are not consulted). An override that
    names no known family returns ``None``.

    Returns ``None`` when no family applies (including an empty
    ``repo_id``) so callers can surface a clear "unsupported model"
    error rather than guessing wrong.
    """
    if override_family:
        wanted = override_family.strip().lower()
        for fam in _FAMILIES + _FULL_REPO_FAMILIES:
            if fam.name == wanted:
                return fam
        return None
    needle = (repo_id or "").lower()
    if not needle:
        return None
    # If repo_id is an absolute (or existing) local path, the whole
    # path would go into ``needle`` and the _FAMILY_EXCLUDE deny lists
    # would match against parent-directory names too. That means
    # ``/home/me/qwen-image-edit-cache/flux-2-klein-4b`` would be
    # excluded because the parent contains ``qwen-image-edit``. Reduce
    # to the leaf when the candidate looks like a filesystem path so
    # matching only considers the model directory itself. Hub ids also
    # contain ``/`` but are neither absolute nor (normally) present on
    # disk, so they keep the full ``owner/repo`` needle.
    if "/" in needle or "\\" in needle:
        try:
            candidate = Path(repo_id).expanduser()
            if candidate.is_absolute() or candidate.exists():
                leaf = candidate.name
                if leaf:
                    needle = leaf.lower()
        except (OSError, ValueError):
            pass
    # Normalise mixed separator spellings (``Qwen_Image-Edit-GGUF``,
    # ``Qwen-Image_Edit-GGUF``, ``Qwen.Image.Edit-GGUF``) and the
    # compact concatenation (``QwenImageEdit-GGUF``) so the
    # _FAMILY_EXCLUDE deny lists do not need every permutation of
    # ``-``, ``_``, ``.`` and run-together spellings.
    # needle_norm: runs of non-alphanumerics collapsed to a single "-".
    needle_norm = re.sub(r"[^a-z0-9]+", "-", needle).strip("-")
    # needle_compact: all non-alphanumerics removed.
    needle_compact = re.sub(r"[^a-z0-9]+", "", needle)
    # Per-token compact strings let ``unsloth/Flux2Klein-GGUF`` match a
    # compact spelling such as ``flux2klein``: the token between the
    # ``/`` and the ``-`` is exactly ``flux2klein``, whereas the
    # whole-needle compact form is ``unslothflux2kleingguf``.
    needle_compact_tokens = {
        re.sub(r"[^a-z0-9]+", "", token)
        for token in re.split(r"[^a-z0-9]+", needle)
        if token
    }

    def _matches_family_token(term: str) -> bool:
        """Token-boundary match on the normalised needle. Prevents
        ``owner/flux.20-model`` from matching ``flux.2`` because
        ``flux.20`` does not have a separator after ``flux-2``.
        Compact spellings (``flux2klein``) match only when they equal a
        complete repo-name token (or the whole compact needle), not
        when they are a substring of a longer token."""
        term_norm = re.sub(r"[^a-z0-9]+", "-", term.lower()).strip("-")
        if not term_norm:
            return False
        # The term must be delimited by "-" or the string ends.
        if re.search(rf"(^|-){re.escape(term_norm)}($|-)", needle_norm):
            return True
        term_compact = re.sub(r"[^a-z0-9]+", "", term.lower())
        if not term_compact:
            return False
        return term_compact in needle_compact_tokens or term_compact == needle_compact

    # Scan _FAMILIES first (GGUF-supported), then _FULL_REPO_FAMILIES
    # so a repo like ``stabilityai/stable-diffusion-xl-base-1.0`` is
    # auto-detected as SDXL instead of returning None.
    for fam in _FAMILIES + _FULL_REPO_FAMILIES:
        excludes = _FAMILY_EXCLUDE.get(fam.name, ())
        # Unlike family terms, excludes are plain substring checks
        # against the raw, normalised and compact needle.
        # NOTE(review): the compact form of a short exclude such as
        # "3.5" is "35", which is matched anywhere in the compact
        # needle -- confirm this cannot reject a legitimate SD3 Medium
        # repo name.
        if any(
            e in needle
            or re.sub(r"[^a-z0-9]+", "-", e).strip("-") in needle_norm
            or re.sub(r"[^a-z0-9]+", "", e) in needle_compact
            for e in excludes
        ):
            continue
        if _matches_family_token(fam.name):
            return fam
        for alias in fam.aliases:
            if alias and _matches_family_token(alias):
                return fam
    return None


def supported_families() -> list[dict[str, str]]:
    """Public-facing list of families for ``/api/inference/images/status``.

    Only ``_FAMILIES`` is listed; the entries in ``_FULL_REPO_FAMILIES``
    are deliberately left out. Each item carries the family ``name``,
    ``pipeline_class`` and default ``base_repo``.
    """
    return [
        {
            "name": fam.name,
            "pipeline_class": fam.pipeline_class,
            "base_repo": fam.base_repo,
        }
        for fam in _FAMILIES
    ]
    def __init__(self) -> None:
        """Create an empty backend: no pipeline loaded, no load pending."""
        # The loaded diffusers pipeline object, or None when nothing is
        # loaded (see ``is_loaded``).
        self._pipe: Any = None
        # `_lock` protects mutations to the small state fields and is
        # the only lock taken by status(). It is intentionally NOT held
        # for the long pipeline forward pass: holding it for the whole
        # generate would block status() polls (frontend at 1 Hz) and
        # any concurrent unload requests for minutes at a time.
        #
        # `_load_lock` serialises the entire load_model call so two
        # concurrent /images/load requests cannot both reach
        # pipeline_cls.from_pretrained at the same time (which would
        # double-spend VRAM and corrupt _pipe).
        #
        # `_generate_lock` serialises pipeline __call__ since diffusers
        # pipelines are not thread-safe; overlapping forwards on the
        # shared pipe corrupt internal scheduler state.
        #
        # Lock order is load -> state and generate -> state (never
        # state -> load/generate) so a forward in flight cannot
        # deadlock the next load or a status poll. load_model acquires
        # _load_lock, then _generate_lock, then _lock.
        self._lock = threading.Lock()
        self._load_lock = threading.Lock()
        self._generate_lock = threading.Lock()
        # Identity of the currently loaded model (all None when idle).
        self._family: Optional[DiffusionFamily] = None
        self._repo_id: Optional[str] = None
        self._gguf_path: Optional[str] = None
        # Original ``gguf_filename`` the caller passed in, preserved
        # so delete guards can compare against subdirectory variants
        # like ``BF16/model.gguf`` or ``Q4_K_M/model.gguf`` instead
        # of the collapsed basename. The basename alone
        # (``model.gguf``) loses the quant directory and lets
        # /delete-cached unlink the wrong file.
        self._gguf_filename: Optional[str] = None
        self._base_repo: Optional[str] = None
        self._device: Optional[str] = None
        self._dtype: Optional[str] = None
        # True when ``enable_model_cpu_offload()`` was applied on the
        # loaded pipeline. Diffusers' offload moves the active
        # submodule between CPU and GPU on each step, so a CUDA
        # ``torch.Generator`` mismatches the CPU-resident embeddings
        # and generation crashes mid-forward. When this is True,
        # seeded generation has to use a CPU generator regardless of
        # self._device.
        self._cpu_offload_enabled: bool = False
        # Bookkeeping surfaced by status(): load timestamp, in-flight
        # flag, and the message of the most recent failure.
        self._loaded_at: Optional[float] = None
        self._loading: bool = False
        self._last_error: Optional[str] = None
        # `_pending_*` fields advertise the target of an in-flight load
        # so cache- and finetuned-delete guards can refuse to rmtree a
        # repo while it is being downloaded / read. They are set under
        # _lock at the start of load_model and cleared on success or
        # in the finally block. The route layer reads them via
        # status() under _lock.
        self._pending_repo_id: Optional[str] = None
        self._pending_base_repo: Optional[str] = None
        self._pending_gguf_filename: Optional[str] = None

    # ── lifecycle ─────────────────────────────────────────────────

    @property
    def is_loaded(self) -> bool:
        """True when a pipeline object is currently held in ``_pipe``."""
        # Read without taking ``_lock``; use status() for a consistent
        # multi-field snapshot.
        return self._pipe is not None

    @property
    def repo_id(self) -> Optional[str]:
        """The raw ``_repo_id`` of the loaded model, or None."""
        # Unlike status(), this is the unredacted value (no
        # ``_display_repo_id`` collapsing) and is read without ``_lock``.
        return self._repo_id
Delete guards must check both: when model A + # is already loaded and a swap to model B is in flight, + # only checking one would let the user rmtree whichever + # repo the guard ignored. UI-facing ``repo_id`` / + # ``base_repo`` / ``gguf_filename`` still prefer pending + # during a swap so the panel shows the load target the + # user just clicked. + active_repo = self._repo_id + active_base = self._base_repo + active_gguf = self._gguf_filename + pending_repo = self._pending_repo_id if self._loading else None + pending_base = self._pending_base_repo if self._loading else None + pending_gguf = self._pending_gguf_filename if self._loading else None + # When a swap is in flight, the UI-facing repo_id / + # base_repo / gguf_filename advertise the PENDING model + # but ``self._family`` still points at the previously + # loaded pipeline. Reporting them together produces a + # repo/family pair that never existed (round 11 #6). + # Null the family / pipeline_class while a swap is in + # flight; the frontend can fall back to "unknown". + ui_family = self._family.name if self._family else None + ui_pipeline_class = self._family.pipeline_class if self._family else None + if pending_repo and pending_repo != active_repo: + ui_family = None + ui_pipeline_class = None + # UI-facing ``gguf_filename`` collapses to the basename + # so the Images panel does not surface internal cache / + # variant directory names. Guard-facing ``active_*`` / + # ``pending_*`` retain the full caller-supplied filename + # so /delete-cached can compare against subdirectory + # variants like ``BF16/model.gguf`` (round 14 P1 #4-5). + ui_gguf = pending_gguf or active_gguf + ui_gguf_basename = Path(ui_gguf).name if ui_gguf else None + # UI-facing ``repo_id`` / ``base_repo`` collapse absolute + # local paths to their leaf name so ``/images/status`` + # does not leak the user's filesystem layout to other + # authenticated browser sessions (round 15 P2 #6). 
The + # guard-facing ``active_*`` / ``pending_*`` fields below + # preserve the exact value so delete guards still match + # against the snapshot path. + payload: dict[str, Any] = { + "is_loaded": self._pipe is not None, + "is_loading": self._loading, + "repo_id": _display_repo_id(pending_repo or active_repo), + "family": ui_family, + "pipeline_class": ui_pipeline_class, + "base_repo": _display_repo_id(pending_base or active_base), + "gguf_filename": ui_gguf_basename, + "device": self._device, + "dtype": self._dtype, + "loaded_at": self._loaded_at, + "last_error": self._last_error, + "supported_families": supported_families(), + } + if include_internal: + # Guard-facing fields: every repo / path / GGUF + # filename the backend owns RIGHT NOW. Delete routes + # iterate both, paired so the variant-filename check + # is compared against the SAME repo that owns it + # (round 13 P1 #3-5). Round 16 P1 #5: never returned + # by the public /images/status route. + payload.update( + { + "active_repo_id": active_repo, + "active_base_repo": active_base, + "active_gguf_filename": active_gguf, + "pending_repo_id": pending_repo, + "pending_base_repo": pending_base, + "pending_gguf_filename": pending_gguf, + } + ) + return payload + + def _pick_device_and_dtype(self) -> tuple[str, "Any"]: + """Pick (device, dtype) for the current host. + + CUDA-first because that is the only path our diffusion GGUFs are + validated on. On macOS we use MPS in float16 to keep the pipeline + on the Metal GPU. CPU is allowed only as a last resort because + running FLUX on CPU is unusably slow (> 10 minutes per image). + + BF16 is gated on ``torch.cuda.is_bf16_supported`` because the + Pascal / Turing class (sm_60 / sm_70 / sm_75) reports + ``is_available() == True`` but lacks BF16 ALUs; FLUX kernels + then fail inside ``from_pretrained`` or at the first denoise + step. Those cards still work on FP16, so fall back rather than + refuse to load. 
def load_model(
    self,
    repo_id: str,
    *,
    gguf_filename: Optional[str] = None,
    base_repo: Optional[str] = None,
    hf_token: Optional[str] = None,
    family_override: Optional[str] = None,
    enable_model_cpu_offload: bool = True,
    ignore_public_load_pending_workload: Optional[str] = None,
) -> dict[str, Any]:
    """Load a diffusion model.

    ``repo_id`` is the Hugging Face repo id of either a GGUF-only
    repo (e.g. ``unsloth/FLUX.2-klein-4B-GGUF``) or a full diffusers
    repo (e.g. ``black-forest-labs/FLUX.2-klein``). When the repo
    contains a GGUF, ``gguf_filename`` picks which quant to load;
    otherwise diffusers' standard config-driven load runs.

    ``base_repo`` overrides the auto-detected diffusers base used
    for VAE / text encoders. ``family_override`` short-circuits the
    substring matcher when an exotic repo name confuses it.

    ``hf_token`` is forwarded to the GGUF download, the base-repo
    preflight probes, ``from_single_file`` and ``from_pretrained``.
    ``enable_model_cpu_offload`` only takes effect when the selected
    device is ``"cuda"``; in that case ``pipe.enable_model_cpu_offload()``
    is called instead of ``pipe.to(device)``.
    ``ignore_public_load_pending_workload`` is passed through as
    ``ignore_pending_workload`` to the helper-advisor busy check.

    Returns the ``status()`` payload after a successful load.

    Raises ``RuntimeError`` on failure with a user-facing message.
    On a failed swap the previous pipeline is also released to
    keep peak VRAM bounded; status() reports is_loaded=false with
    last_error set so the caller can react.
    """
    # Surface a friendly load error when the no-torch / partial
    # install path is active: the user clicked Load on the Images
    # page but the runtime never installed torch + diffusers (round
    # 13 P2 #12). Without this wrapper the import surfaces as a
    # raw ``ModuleNotFoundError`` -> 500 instead of a 400 the UI
    # can display.
    try:
        from huggingface_hub import hf_hub_download
        import diffusers
        import torch
    except ModuleNotFoundError as exc:
        missing = exc.name or str(exc)
        raise RuntimeError(
            "Diffusion image generation requires the torch / diffusers "
            f"runtime. Missing dependency: {missing}. Install the Studio "
            "torch runtime (re-run setup.sh / install.ps1) before "
            "loading an image model."
        ) from exc

    # Round 30 P1 #11: also preflight transformers BEFORE any
    # destructive unload. Diffusers can expose stub pipeline
    # classes when transformers is missing or broken, so the load
    # would otherwise tear down chat first and fail later inside
    # from_pretrained. Use find_spec (no module execution) so test
    # environments that stub these modules still pass the preflight
    # without us actually importing them.
    # Round 34: accelerate is pulled in transitively by every
    # supported transformers install path (it is a hard runtime
    # dep of transformers' PyTorch backend), so a separate
    # find_spec("accelerate") guard is redundant in practice and
    # broke the CI test matrix where the test env ships
    # transformers without accelerate. The offload code path
    # (``enable_model_cpu_offload`` / ``device_map="auto"``)
    # will surface a clean ModuleNotFoundError if a user somehow
    # arrives at an offload-needed load without it.
    import importlib.util as _ilu

    if _ilu.find_spec("transformers") is None:
        raise RuntimeError(
            "Diffusion image generation requires the Studio torch "
            "runtime. Missing dependency: transformers. Install the "
            "Studio torch runtime (re-run setup.sh / install.ps1) "
            "before loading an image model."
        )

    fam = detect_family(repo_id, override_family = family_override)
    if fam is None:
        # Round 22 P2 #4: route the repo label through
        # ``_display_repo_id`` so a local absolute path that did
        # not match any family does not leak the operator's
        # filesystem layout via the error message / last_error
        # / 400 response body.
        raise RuntimeError(
            f"Could not infer a diffusion family for '{_display_repo_id(repo_id)}'. "
            "Pass family_override = 'flux.2-klein' / 'flux.2' / "
            "'flux.1' / 'qwen-image' / 'stable-diffusion-3' / "
            "'stable-diffusion-xl' to disambiguate."
        )

    device, dtype = self._pick_device_and_dtype()

    # Round 32 P1 #3: track whether the backend-side
    # helper-busy check published a "diffusion-backend" pending
    # entry so the outer finally clears the matching publish
    # exactly once. Set inside the try below right after the
    # snapshot succeeds.
    backend_pending_published = False

    # _load_lock serialises the entire load so two concurrent calls
    # cannot both kick off a multi-GB download + GPU upload at once.
    # The second caller waits behind the first and then loads on top
    # of the now-populated state via the normal swap path.
    # _generate_lock is also taken so we do not start swapping the
    # pipeline (release old + allocate new) while a previous
    # generation is still iterating denoising steps; releasing the
    # pipe out from under an in-flight forward corrupts scheduler
    # state. Order: _load_lock -> _generate_lock -> _lock so a
    # forward (which only takes _generate_lock + briefly _lock)
    # cannot block a queued load forever.
    with self._load_lock, self._generate_lock:
        with self._lock:
            self._loading = True
            self._last_error = None
            # Publish the pending target so cache / finetuned
            # delete guards can see what is mid-download even
            # before _repo_id / _base_repo are populated on
            # success.
            self._pending_repo_id = repo_id
            self._pending_base_repo = base_repo
            # Store the caller's full ``gguf_filename`` (e.g.
            # ``BF16/model.gguf``) so the variant-aware delete
            # guards have the subdirectory info. The UI side of
            # status() still collapses to the basename for display.
            self._pending_gguf_filename = gguf_filename if gguf_filename else None
        try:
            pipeline_cls = getattr(diffusers, fam.pipeline_class, None)
            if pipeline_cls is None:
                raise RuntimeError(
                    f"diffusers {diffusers.__version__} has no "
                    f"{fam.pipeline_class}; upgrade diffusers and retry."
                )
            transformer_cls = (
                getattr(diffusers, fam.transformer_class, None)
                if fam.transformer_class
                else None
            )

            # Resolution rules for the "what repo to call
            # from_pretrained on" question:
            #   1. no GGUF file -> caller is loading a full
            #      diffusers repo; use repo_id directly so we do
            #      not silently substitute the family default
            #      AND ignore any base_repo input (it is only
            #      meaningful as a GGUF companion override). The
            #      old order let ``base_repo`` swap a fine-tuned
            #      ``owner/my-flux.1-finetune`` for
            #      ``black-forest-labs/FLUX.1-dev`` while status
            #      still advertised the user's repo (round 13
            #      P2 #10).
            #   2. otherwise prefer caller-supplied base_repo for
            #      the missing VAE / text encoder components.
            #   3. otherwise use the family + repo_id heuristic so
            #      a 9B GGUF picks the 9B base, not the 4B fallback.
            if not gguf_filename:
                # Guard: a repo that ends in "-GGUF" (the unsloth
                # convention) is GGUF-only and will 500 on
                # from_pretrained; surface a clear error instead of
                # letting diffusers raise a confusing model-index
                # failure deep in the loader.
                if repo_id.lower().endswith("-gguf"):
                    raise RuntimeError(
                        f"'{repo_id}' looks like a GGUF-only repo. "
                        "Either provide gguf_filename to pick a quant, "
                        "or load a full diffusers repo (base_repo only "
                        "applies when picking a GGUF quant)."
                    )
                # ``~/models/my-flux`` must be expanded so
                # diffusers' from_pretrained does not pass the
                # literal tilde through to ``os.path.isdir`` and
                # fall back to the Hub (round 14 P2 #11).
                effective_base = _expand_existing_local_path(repo_id)
                with self._lock:
                    self._pending_base_repo = effective_base
            elif base_repo:
                effective_base = _expand_existing_local_path(base_repo)
                # Refresh pending so delete guards see the actual
                # base, not just caller-supplied None.
                with self._lock:
                    self._pending_base_repo = effective_base
            else:
                effective_base = _smart_base_repo(fam, repo_id)
                with self._lock:
                    self._pending_base_repo = effective_base
            # ``repo_id`` / ``effective_base`` are user-supplied
            # strings that can embed an ``hf_xxxxx`` token via a
            # URL-style path (``https://hf_token@huggingface.co/...``).
            # Scrub them BEFORE the logger formats the line so the
            # token never reaches structured-log sinks (round 14
            # P2 #9).
            # Round 23 P2 #11: ``_redact_hf_tokens`` only scrubs
            # ``hf_xxxxx`` substrings, so an absolute local
            # path like ``/home/alice/private/FLUX.2-klein-GGUF``
            # used to land in this log line verbatim. Route
            # through ``_display_repo_id`` so the leaf is
            # logged when the value is a filesystem path, with
            # the token-redaction step inside that helper as a
            # belt-and-braces defence.
            logger.info(
                "Loading diffusion model %s (family=%s, device=%s, dtype=%s, base=%s)",
                _display_repo_id(repo_id),
                fam.name,
                device,
                dtype,
                _display_repo_id(effective_base),
            )

            transformer = None
            local_gguf_path: Optional[str] = None
            if gguf_filename:
                if transformer_cls is None:
                    raise RuntimeError(
                        f"Family {fam.name} does not have a GGUF transformer "
                        "path wired in this build; load the full repo instead."
                    )
                # DiffusionLoadRequest.repo_id is documented to
                # accept either a Hub repo id OR a local path
                # (Studio export, downloaded HF snapshot, etc.).
                # We accept BOTH absolute and relative local
                # directories: Studio exports surface as relative
                # paths like ``exports/my-flux`` and earlier
                # versions only accepted absolute paths, falling
                # through to ``hf_hub_download`` which then
                # raised HFValidationError on the relative path
                # (round 13 P1 #2). For local paths we route the
                # gguf_filename through ``_resolve_local_gguf_child``
                # so traversal (``../secret.gguf``) and absolute
                # filename escapes (``/etc/passwd``) are rejected
                # BEFORE the file is opened, which also keeps the
                # delete-ownership guards aligned with what was
                # actually loaded.
                repo_id_path = Path(repo_id).expanduser()
                if repo_id_path.is_dir():
                    local_gguf_path = str(
                        _resolve_local_gguf_child(repo_id_path, gguf_filename)
                    )
                else:
                    local_gguf_path = hf_hub_download(
                        repo_id = repo_id,
                        filename = gguf_filename,
                        token = hf_token,
                    )

            # Round 20 P1 #1: every load mode (full diffusers
            # repo, GGUF + explicit base_repo, GGUF + auto-picked
            # base_repo) feeds ``effective_base`` into
            # ``from_pretrained`` further down. The round 19
            # preflight only ran for the first two, so an
            # auto-picked GGUF companion that turned out to be
            # gated / private / missing still unloaded chat
            # before the load failed. Always preflight
            # ``effective_base`` so a bad companion repo is
            # caught BEFORE chat / export are released.
            _preflight_full_diffusers_repo(effective_base, hf_token)
            # Round 21 P2 #6: the GGUF transformer path also
            # consumes ``effective_base`` via
            # ``from_single_file(config=effective_base,
            # subfolder="transformer")``. A base that has
            # ``model_index.json`` but lacks
            # ``transformer/config.json`` would pass the
            # round-19 preflight and only fail AFTER the chat
            # unload. Run the subfolder probe too so the
            # second cheap failure mode is also caught early.
            if gguf_filename and fam.transformer_class:
                _preflight_diffusers_subfolder_config(
                    effective_base,
                    "transformer",
                    hf_token,
                )

            # Round 20 P1 #2: ``diffusers.GGUFQuantizationConfig``
            # imports the ``gguf`` package lazily at construction
            # time. Partial Studio installs (``diffusers`` present,
            # ``gguf`` not) used to discover that AFTER the chat /
            # export release calls. Build the quant config up
            # front so the missing-dependency surface raises
            # while the user's chat model is still resident.
            quant_config = None
            if gguf_filename:
                try:
                    quant_config = diffusers.GGUFQuantizationConfig(
                        compute_dtype = dtype
                    )
                except ModuleNotFoundError as exc:
                    missing = exc.name or str(exc)
                    raise RuntimeError(
                        "Diffusion GGUF loading requires the gguf "
                        "runtime package. Missing dependency: "
                        f"{missing}. Re-run Studio setup before "
                        "loading an image GGUF."
                    ) from exc

            # All cheap failure points (bad gguf_filename, missing
            # pipeline / transformer class, gated download token,
            # transient Hub error on the GGUF download) have now
            # been validated. Anything past this line allocates
            # GPU memory, so:
            #   1. Verify training is idle and the export job (if
            #      any) is also idle. ``_release_other_gpu_owners
            #      _for_diffusion`` RAISES on conflict, so it must
            #      run BEFORE we unload chat (round 16 P1 #2): a
            #      route precheck -> worker race could otherwise
            #      drop the user's chat model only to bail out
            #      because training started in between, and a
            #      direct ``DiffusionBackend.load_model`` caller
            #      that did not run the route prechecks would also
            #      leave chat unloaded for nothing.
            #   2. Release the chat backend (llama-server + the
            #      safetensors orchestrator) now that we know the
            #      load can actually proceed.
            #   3. Release any *previous* diffusion pipeline so the
            #      new transformer / new from_pretrained does not
            #      race the old pipe for VRAM. Switching between
            #      FLUX.2 klein 4B and 9B on a 16-24 GB GPU OOMs
            #      otherwise: from_single_file allocates the new
            #      transformer while the old pipeline still owns
            #      its weights.
            #   4. THEN call from_single_file / from_pretrained.
            # Round 29 P1 #1: do ALL cheap conflict checks BEFORE
            # any destructive unload, so a training/export conflict
            # caught inside _release_other_gpu_owners_for_diffusion
            # does NOT leave the user with no chat model after we
            # already unloaded it. The helper-busy check is
            # split out of _release_chat_backend_for_diffusion;
            # _release_other_gpu_owners_for_diffusion raises
            # RuntimeError early when training/export is active
            # without touching the chat backend.
            # Round 32 P1 #3: publish a backend-side pending
            # entry under the helper-advisor start lock so a
            # direct / test / future caller of this method is
            # symmetric with the route layer's
            # _raise_if_helper_advisor_busy("diffusion"). The
            # route's "diffusion" tag and this "diffusion-
            # backend" tag refcount independently; both
            # contribute to public_load_pending().
            backend_pending_published = _raise_if_helper_advisor_busy_for_diffusion(
                publish_pending = True,
                ignore_pending_workload = (ignore_public_load_pending_workload),
            )
            _release_other_gpu_owners_for_diffusion()
            _release_chat_backend_for_diffusion(check_helper_advisor = False)

            old = self._pipe
            if old is not None:
                with self._lock:
                    # Clear ALL metadata together so a failed swap
                    # cannot leave status() reporting the previous
                    # repo / family / base_repo on top of an empty
                    # pipe. The except block below will restore
                    # last_error so the caller knows what happened.
                    self._pipe = None
                    self._family = None
                    self._repo_id = None
                    self._gguf_path = None
                    self._gguf_filename = None
                    self._base_repo = None
                    self._device = None
                    self._dtype = None
                    self._cpu_offload_enabled = False
                    self._loaded_at = None
                _release(old)
                old = None
                # Now that both the attribute and the local
                # have been nulled, the pipeline is unreachable;
                # ask the CUDA allocator to release its slabs so
                # the next from_pretrained does not OOM behind
                # an already-freed-but-cached arena.
                _drain_cuda_cache()

            if gguf_filename:
                # ``quant_config`` was already constructed above
                # (round 20 P1 #2 pre-release fail-fast).
                # Diffusers-format GGUFs (FLUX.2 klein / Qwen-Image /
                # SD3) need the matching base repo's component config
                # at config=<base repo>, subfolder="transformer".
                # Older city96-style GGUFs ignore those kwargs. The
                # token is also passed because gated GGUF repos
                # require it both at download and at config read time.
                single_file_kwargs: dict[str, Any] = {
                    "quantization_config": quant_config,
                    "torch_dtype": dtype,
                    "config": effective_base,
                    "subfolder": "transformer",
                }
                if hf_token:
                    single_file_kwargs["token"] = hf_token
                transformer = transformer_cls.from_single_file(
                    local_gguf_path,
                    **single_file_kwargs,
                )

            pipe_kwargs: dict[str, Any] = {
                "torch_dtype": dtype,
                # use_safetensors=True refuses pickle-backed .bin
                # weights at load time. Diffusers will fall back to
                # safetensors variants on repos that publish both,
                # and hard-error on repos that only ship .bin (which
                # is the threat model we want to block since pickle
                # files can execute arbitrary code in this process).
                "use_safetensors": True,
            }
            if transformer is not None:
                pipe_kwargs["transformer"] = transformer
            if hf_token:
                pipe_kwargs["token"] = hf_token

            pipe = None
            cpu_offload_enabled = bool(
                enable_model_cpu_offload and device == "cuda"
            )
            try:
                pipe = pipeline_cls.from_pretrained(effective_base, **pipe_kwargs)
                # Device placement / offload can ALSO raise after
                # from_pretrained succeeded (OOM at the .to(device)
                # copy, accelerate offload hook misconfigured, etc.).
                # If we let the exception escape now, the local
                # ``pipe`` lives on the traceback frame until the
                # caller drops it, holding multi-GB of VRAM behind
                # the next load attempt. Explicitly release both
                # pipe and transformer in the same try (round 13
                # P2 #11).
                if cpu_offload_enabled:
                    pipe.enable_model_cpu_offload()
                else:
                    pipe.to(device)
            except Exception:
                if pipe is not None:
                    _release(pipe)
                    pipe = None
                if transformer is not None:
                    _release(transformer)
                    transformer = None
                _drain_cuda_cache()
                raise

            with self._lock:
                self._pipe = pipe
                self._family = fam
                self._repo_id = repo_id
                self._gguf_path = local_gguf_path
                # Preserve the full caller-supplied filename, not
                # just the basename, so per-variant delete guards
                # see ``BF16/model.gguf`` (round 14 P1 #4).
                self._gguf_filename = gguf_filename if gguf_filename else None
                self._base_repo = effective_base
                self._device = device
                # Stored as a plain string ("bfloat16", "float16", ...)
                # rather than a torch.dtype object.
                self._dtype = str(dtype).replace("torch.", "")
                self._cpu_offload_enabled = cpu_offload_enabled
                self._loaded_at = time.time()
                # Clear loading + pending here, BEFORE returning,
                # so the response payload reports the resident
                # pipeline cleanly (is_loading=false, no pending_*).
                # The ``finally`` block below is idempotent and
                # still clears on error / early raise paths.
                self._loading = False
                self._pending_repo_id = None
                self._pending_base_repo = None
                self._pending_gguf_filename = None

            return self.status()
        except Exception as exc:
            # Scrub hf_token and pipe_kwargs from frame locals BEFORE
            # logger.exception() captures them. Rich tracebacks and
            # some structlog formatters render frame locals, which
            # would otherwise echo the raw hf_... token into logs
            # and any error reporting sink the user has wired up.
            # ALSO scrub the exception message itself: huggingface_hub
            # / diffusers can include the bearer token verbatim in
            # 401 / 403 messages, which would propagate through
            # ``_last_error`` (rendered in status()) and the
            # user-facing RuntimeError (rendered in route responses).
            scrub_token = hf_token
            hf_token = None  # noqa: F841
            pipe_kwargs = None  # noqa: F841
            single_file_kwargs = None  # noqa: F841
            exc_msg = str(exc)
            if scrub_token:
                exc_msg = exc_msg.replace(scrub_token, "")
            # Hugging Face tokens are prefixed ``hf_``; replace any
            # leftover ``hf_...`` substrings to catch tokens we did
            # not store as ``scrub_token`` (e.g. cached tokens that
            # huggingface_hub picked up on its own).
            import re

            exc_msg = re.sub(r"hf_[A-Za-z0-9]{20,}", "", exc_msg)

            # Round 17 P2 #9: diffusers / safetensors raise errors
            # like ``FileNotFoundError: /home/alice/models/foo.gguf``
            # or ``OSError: Error while loading state dict from
            # C:\\Users\\bob\\repos\\flux``. These messages flow
            # into ``_last_error`` (rendered by status() to every
            # authenticated browser tab) and the user-facing
            # RuntimeError, which would leak the operator's
            # filesystem layout to other sessions. Collapse the
            # known repo / base / gguf paths to their leaf name
            # using the same convention as _display_repo_id().
            def _collapse_local(msg: str, candidate: Optional[str]) -> str:
                # Replace every occurrence of ``candidate`` (and its
                # resolved absolute form) in ``msg`` with its final
                # path component; non-string / empty candidates leave
                # ``msg`` untouched.
                if not candidate or not isinstance(candidate, str):
                    return msg
                try:
                    p = Path(candidate).expanduser()
                except (OSError, ValueError):
                    return msg
                leaf = p.name or candidate
                needles: set[str] = set()
                # Round 20 P2 #6: a relative candidate like
                # ``exports/my-flux`` used to collapse only the
                # exact ``exports/my-flux`` substring, but
                # downstream libraries (diffusers / safetensors)
                # resolve and emit ``/mnt/disks/.../exports/my-flux/...``
                # absolute strings that leaked the operator's
                # filesystem layout. Also scrub the resolved
                # absolute form so the leaf is the only path
                # fragment that survives.
                try:
                    if p.exists():
                        needles.add(str(p.resolve()))
                    elif p.is_absolute():
                        needles.add(str(p))
                except (OSError, ValueError):
                    pass
                if "/" in candidate or "\\" in candidate:
                    needles.add(candidate)
                # Replace longest first so a parent-directory
                # substring does not blank out the leaf-only
                # context the user needs.
                for needle in sorted(
                    (n for n in needles if n and n != leaf),
                    key = len,
                    reverse = True,
                ):
                    msg = msg.replace(needle, leaf)
                return msg

            # ``effective_base`` and ``gguf_filename`` are local
            # to the try block above and may be unbound if the
            # exception fired before assignment (e.g. the GGUF
            # repo / filename validation raises before
            # ``effective_base`` is computed). ``locals().get``
            # keeps the scrub a no-op in that case.
            # Round 18 P2 #9: also scrub ``local_gguf_path``. The
            # GGUF quant is loaded via
            # ``transformer_cls.from_single_file(local_gguf_path)``,
            # and diffusers / safetensors errors include the
            # resolved absolute HF cache path
            # (``/home/alice/.cache/huggingface/hub/.../flux.gguf``).
            # Without this the cache path would leak into
            # ``_last_error`` (and therefore status() / log lines).
            _locals = locals()
            exc_msg = _collapse_local(exc_msg, repo_id)
            exc_msg = _collapse_local(exc_msg, _locals.get("effective_base"))
            exc_msg = _collapse_local(exc_msg, _locals.get("gguf_filename"))
            exc_msg = _collapse_local(exc_msg, _locals.get("local_gguf_path"))
            with self._lock:
                self._last_error = exc_msg
            # ``logger.exception`` would emit the raw exception
            # (including any unredacted ``hf_...`` token inside
            # the message OR traceback locals on rich loggers).
            # Use ``logger.error`` with the already-scrubbed
            # message and exc_info=False so the bearer token
            # cannot leak through structured logging sinks.
            # Round 23 P2 #12: same fix as the start-of-load
            # log above. ``_redact_hf_tokens`` alone left
            # absolute local repo paths in this failure line.
            logger.error(
                "Diffusion load failed for %s: %s",
                _display_repo_id(repo_id),
                exc_msg,
            )
            raise RuntimeError(
                f"Failed to load diffusion model: {exc_msg}"
            ) from exc
        finally:
            with self._lock:
                self._loading = False
                # Clear pending so status() falls back to publishing
                # the resident pipeline (or nothing, on a failed
                # swap). Keeping pending alive after the load
                # finishes would falsely block deletes forever.
                self._pending_repo_id = None
                self._pending_base_repo = None
                self._pending_gguf_filename = None
            # Round 32 P1 #3: clear the backend-side public-load
            # pending publish if it was set. Skipped when the
            # helper-busy snapshot raised (no publish to clear)
            # so the counter stays in sync with publishes.
            if backend_pending_published:
                _clear_diffusion_backend_pending()
def unload_model(self) -> dict[str, Any]:
    """Drop the resident pipeline and report ``{"is_loaded": False}``.

    Holds ``_load_lock`` and ``_generate_lock`` for the whole call so
    an in-flight ``load_model`` cannot repopulate the state after we
    return, and so a running generation finishes before its pipeline
    is released. ``_loading`` is raised while the release + cache
    drain is in progress and lowered in ``finally``.
    """
    # Attributes that simply go back to ``None`` once nothing is loaded.
    cleared_fields = (
        "_pipe",
        "_family",
        "_repo_id",
        "_gguf_path",
        "_gguf_filename",
        "_base_repo",
        "_device",
        "_dtype",
    )
    with self._load_lock, self._generate_lock:
        with self._lock:
            stale_pipe = self._pipe
            # Flag the slot as busy before the pipe reference is
            # cleared, so readers that treat is_loaded OR is_loading
            # as "busy" keep seeing a busy backend during teardown.
            self._loading = True
            for field_name in cleared_fields:
                setattr(self, field_name, None)
            self._cpu_offload_enabled = False
            self._loaded_at = None
        try:
            _release(stale_pipe)
            # Drop the last local reference before draining the cache.
            stale_pipe = None  # noqa: F841
            _drain_cuda_cache()
        finally:
            with self._lock:
                self._loading = False
    return {"is_loaded": False}
def generate_image(
    self,
    *,
    prompt: str,
    negative_prompt: Optional[str] = None,
    num_inference_steps: int = 24,
    guidance_scale: float = 3.5,
    width: int = 1024,
    height: int = 1024,
    seed: Optional[int] = None,
) -> "Any":
    """Generate one image and return it.

    Thin locked wrapper: acquires ``_generate_lock`` and delegates to
    ``_generate_image_unlocked`` with the same keyword arguments, so
    only one generation runs at a time and load / unload calls that
    also take ``_generate_lock`` queue behind it. The state ``_lock``
    is not held here; the inner method takes it only briefly.
    """
    request: dict[str, Any] = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "width": width,
        "height": height,
        "seed": seed,
    }
    # Take the generate lock BEFORE the pipeline is snapshotted (that
    # happens inside the inner call) so a swap cannot slip in between.
    with self._generate_lock:
        return self._generate_image_unlocked(**request)
def _generate_image_unlocked(
    self,
    *,
    prompt: str,
    negative_prompt: Optional[str] = None,
    num_inference_steps: int = 24,
    guidance_scale: float = 3.5,
    width: int = 1024,
    height: int = 1024,
    seed: Optional[int] = None,
) -> "Any":
    """Run one generation; the caller must already hold ``_generate_lock``.

    Validates the request, snapshots the pipeline under ``_lock``,
    builds the call kwargs and returns the first image the pipeline
    produces.

    Raises ``ValueError`` for an empty prompt, a step count outside
    [1, 200], or a width / height outside (0, 2048] or not a multiple
    of 8. Raises ``RuntimeError`` when no pipeline is loaded or the
    pipeline returns no images.
    """
    if not (prompt and prompt.strip()):
        raise ValueError("prompt is empty")
    if num_inference_steps < 1 or 200 < num_inference_steps:
        raise ValueError("num_inference_steps must be in [1, 200]")
    sides = (width, height)
    if any(side <= 0 or side > 2048 for side in sides):
        raise ValueError("width and height must be in (0, 2048]")
    # Reject sizes that are not multiples of 8 up front instead of
    # letting the pipeline fail later with a less helpful error.
    if any(side % 8 for side in sides):
        raise ValueError("width and height must be multiples of 8")

    import torch

    # Hold the state lock only long enough to copy what we need.
    with self._lock:
        pipe = self._pipe
        if pipe is None:
            raise RuntimeError("No diffusion model is loaded.")
        device = self._device or "cpu"
        offloaded = self._cpu_offload_enabled

    generator = None
    if seed is not None:
        # With CPU offload active the generator stays on the CPU; a
        # CUDA generator is used only for a non-offloaded CUDA pipe on
        # a host where CUDA is actually available.
        on_cuda = (not offloaded) and device == "cuda" and torch.cuda.is_available()
        rng_device = "cuda" if on_cuda else "cpu"
        generator = torch.Generator(device = rng_device).manual_seed(int(seed))

    call_kwargs: dict[str, Any] = {
        "prompt": prompt,
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": float(guidance_scale),
        "width": int(width),
        "height": int(height),
    }

    # Only forward a negative prompt to pipelines whose ``__call__``
    # advertises it; otherwise log and drop it.
    has_negative = negative_prompt is not None and bool(negative_prompt.strip())
    if has_negative and not _pipe_accepts_kwarg(pipe, "negative_prompt"):
        logger.info(
            "Dropping negative_prompt: %s does not accept it",
            type(pipe).__name__,
        )
    elif has_negative:
        call_kwargs["negative_prompt"] = negative_prompt
        # Pipelines that expose ``true_cfg_scale`` get the same
        # user-supplied guidance value there as well, so the negative
        # prompt has an effect.
        if _pipe_accepts_kwarg(pipe, "true_cfg_scale"):
            call_kwargs["true_cfg_scale"] = float(guidance_scale)

    if generator is not None:
        call_kwargs["generator"] = generator

    result = pipe(**call_kwargs)
    images = getattr(result, "images", None)
    if not images:
        raise RuntimeError("Diffusion pipeline returned no images.")
    return images[0]
def generate_image_with_metadata(
    self,
    **kwargs: Any,
) -> tuple[Any, dict[str, Any]]:
    """Generate one image and capture which model produced it.

    Returns ``(pil_image, {"model": <display repo id>, "family":
    <family name or None>})``. The metadata is read while
    ``_generate_lock`` is still held (plus ``_lock`` for the read
    itself), so a queued load / unload cannot swap ``_repo_id`` /
    ``_family`` between the forward pass finishing and the snapshot.
    ``kwargs`` are passed unchanged to ``_generate_image_unlocked``.
    """
    with self._generate_lock:
        image = self._generate_image_unlocked(**kwargs)
        with self._lock:
            family = self._family
            # ``model`` goes through _display_repo_id rather than
            # returning the raw ``_repo_id`` value.
            meta = {
                "model": _display_repo_id(self._repo_id),
                "family": family.name if family else None,
            }
    return image, meta
def _pipe_accepts_kwarg(pipe: Any, name: str) -> bool:
    """Return True if ``pipe.__call__`` can take a kwarg called ``name``.

    A parameter literally named ``name`` counts, and so does a
    ``**kwargs`` catch-all in the signature. Inspect-based so no
    per-pipeline allow-list has to be maintained.

    Returns False on any introspection failure, including an object
    that has no ``__call__`` at all, so callers stay on the safe
    "do not forward the kwarg" path instead of crashing.
    """
    import inspect

    # Objects without ``__call__`` used to raise AttributeError here,
    # which contradicted the "False on any introspection error"
    # contract; resolve the attribute defensively instead.
    call = getattr(pipe, "__call__", None)
    if call is None:
        return False
    try:
        sig = inspect.signature(call)
    except (TypeError, ValueError):
        return False
    if name in sig.parameters:
        return True
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
    )


def encode_png_base64(pil_image: "Any") -> str:
    """Encode a PIL image as PNG and return it as a base64 ASCII string.

    The image is saved into an in-memory buffer with ``format="PNG"``
    and ``optimize=True``; nothing is written to disk.
    """
    import base64

    buf = io.BytesIO()
    pil_image.save(buf, format = "PNG", optimize = True)
    return base64.b64encode(buf.getvalue()).decode("ascii")
def _raise_if_helper_advisor_busy_for_diffusion(
    *,
    publish_pending: bool = False,
    ignore_pending_workload: Optional[str] = None,
) -> bool:
    """Refuse a diffusion load while AI Assist or another load owns the GPU.

    Runs under ``_HELPER_ADVISOR_START_LOCK``:

    * raises ``RuntimeError`` when ``helper_advisor_busy()`` is true;
    * when ``publish_pending`` is set, also raises ``RuntimeError`` if
      ``public_load_pending(excluding=ignore_pending_workload)``
      reports another workload, and otherwise publishes a
      ``"diffusion-backend"`` pending entry.

    Returns True only when a pending entry was published, so the
    caller knows to pair it with ``_clear_diffusion_backend_pending``.
    Returns False when nothing was published, including when the
    ``utils.datasets.llm_assist`` helpers cannot be imported.
    """
    try:
        from utils.datasets.llm_assist import (
            _HELPER_ADVISOR_START_LOCK,
            _publish_public_load_pending,
            helper_advisor_busy,
            public_load_pending,
        )
    except Exception:
        # Helpers module unavailable: nothing to check or publish.
        return False

    with _HELPER_ADVISOR_START_LOCK:
        if helper_advisor_busy():
            raise RuntimeError(
                "AI Assist (helper / advisor GGUF) is still using the GPU. "
                "Wait for it to finish before loading a diffusion image model."
            )
        if not publish_pending:
            return False
        # ``ignore_pending_workload`` lets a caller that already
        # published its own pending tag avoid blocking on itself.
        if public_load_pending(excluding = ignore_pending_workload):
            raise RuntimeError(
                "Another GPU workload is mid-handoff. Wait for it to "
                "finish before loading a diffusion image model."
            )
        _publish_public_load_pending("diffusion-backend")
        return True


def _clear_diffusion_backend_pending() -> None:
    """Release the ``"diffusion-backend"`` pending entry, best effort.

    Counterpart of ``_raise_if_helper_advisor_busy_for_diffusion``
    called with ``publish_pending=True``. Never raises: a missing
    helpers module or a failing release call is silently ignored.
    """
    try:
        from utils.datasets.llm_assist import _release_public_load_pending

        _release_public_load_pending("diffusion-backend")
    except Exception:
        # Deliberately swallowed: this runs from a ``finally`` block
        # and must not mask the original load outcome.
        return
Asking both to release + their weights first means a typical 24 GB consumer GPU can host + one chat model OR one diffusion model without manual unload steps. + + A missing chat backend module is a silent no-op (fresh install / + no GGUF use). An unload that ACTUALLY fails (raises or leaves + the backend resident) raises ``RuntimeError`` so the surrounding + diffusion ``load_model`` bails out instead of double-owning VRAM + (round 17 P1 #2). + """ + # Round 27 P1 #2 / round 29 P1 #1: helper / advisor GGUF loads + # run on a PRIVATE LlamaCppBackend so the global llama check below + # cannot see them. The actual busy check now lives in + # _raise_if_helper_advisor_busy_for_diffusion so the caller can do + # ALL conflict checks BEFORE any destructive unload. Kept here as + # a default-on safety net for callers that did not run the + # standalone check. + if check_helper_advisor: + _raise_if_helper_advisor_busy_for_diffusion() + # 1. GGUF chat backend (llama-server subprocess). We unload when + # EITHER is_loaded is True (resident model) OR is_active is + # True (mid-download / startup) OR loading_model_identifier is + # populated (HF GGUF download in progress, before is_active / + # is_loaded flip). The last case is what round 13 P1 #8 flagged. 
+ try: + from routes.inference import get_llama_cpp_backend # type: ignore + except Exception as exc: + logger.debug("llama-server unavailable before diffusion load: %s", exc) + else: + backend = get_llama_cpp_backend() + is_loaded = bool(getattr(backend, "is_loaded", False)) + is_active = bool(getattr(backend, "is_active", False)) + is_loading = bool(getattr(backend, "loading_model_identifier", None)) + if is_loaded or is_active or is_loading: + logger.info( + "Unloading llama-server (loaded=%s active=%s loading=%s) before diffusion load", + is_loaded, + is_active, + is_loading, + ) + try: + ok = backend.unload_model() + except Exception as exc: + raise RuntimeError( + "Could not unload the existing GGUF chat model before " + "loading a diffusion image model." + ) from exc + # Round 28 P1 #12: a cancelled pending GGUF download takes + # up to a few seconds to clear loading_model_identifier in + # its finally block. Wait briefly so the same retryable + # cancel path used by the unload route does not 503 us. + deadline = time.monotonic() + 5.0 + while ( + getattr(backend, "loading_model_identifier", None) + and time.monotonic() < deadline + ): + time.sleep(0.1) + # Round 18 P1 #4: also reject when ``loading_model_identifier`` + # is still set after the unload call. Without this, a GGUF + # download / startup that was already in flight before the + # diffusion handoff (and which never flipped is_active to + # True before the unload landed) keeps allocating into VRAM + # while diffusion proceeds, double-owning the GPU. + if ( + ok is False + or getattr(backend, "is_loaded", False) + or getattr(backend, "is_active", False) + or getattr(backend, "loading_model_identifier", None) + ): + raise RuntimeError( + "The existing GGUF chat model is still active or loading " + "after unload; retry before loading a diffusion image model." + ) + + # 2. Safetensors / HF chat backend (the InferenceOrchestrator that + # serves FastVisionModel / FastLanguageModel weights). 
When this + # backend has a model resident on the same GPU, a diffusion load + # will OOM the same way. We also flush any loading_models set so + # a chat load that is mid-download cannot race the diffusion + # allocation. + try: + from core.inference import get_inference_backend # type: ignore + except Exception as exc: + logger.debug("safetensors unavailable before diffusion load: %s", exc) + return + + backend = get_inference_backend() + active_model_name = getattr(backend, "active_model_name", None) + loading_models = set(getattr(backend, "loading_models", set()) or set()) + + def _require_unload(model_name: str) -> None: + try: + ok = backend.unload_model(model_name) + except Exception as exc: + raise RuntimeError( + f"Could not unload safetensors chat model '{model_name}' " + "before loading a diffusion image model." + ) from exc + if ok is False: + raise RuntimeError( + f"Safetensors backend refused to unload '{model_name}' " + "before loading a diffusion image model." + ) + # Round 19 P1 #2: per-name post-state check. ``unload_model`` + # returning ``True`` does not guarantee the orchestrator + # actually dropped the weights; the worker may have responded + # while still holding them, or a concurrent ``load`` may have + # repopulated the tracker. Verify the specific name is gone + # so the surrounding diffusion load bails out instead of + # silently double-owning VRAM. + active_after = getattr(backend, "active_model_name", None) + loading_after = set(getattr(backend, "loading_models", set()) or set()) + if active_after == model_name or model_name in loading_after: + raise RuntimeError( + f"Safetensors chat model '{model_name}' is still active " + "or loading after unload; retry before loading a diffusion image model." 
+ ) + + if active_model_name: + logger.info( + "Unloading safetensors chat backend '%s' before diffusion load", + active_model_name, + ) + _require_unload(active_model_name) + for loading in loading_models: + if loading == active_model_name: + continue + logger.info( + "Unloading in-flight safetensors chat load '%s' before diffusion", + loading, + ) + _require_unload(loading) + + # Round 21 P1 #5: final sweep without the owned_names filter. + # A concurrent ``/load`` that appeared AFTER the initial + # snapshot was previously ignored, so a chat model that started + # loading during the diffusion handoff slipped through and + # raced the diffusion allocation for VRAM. Treat ANY surviving + # active / loading entry as a failure so the surrounding + # load_model raises and the caller retries. + remaining_loading = set(getattr(backend, "loading_models", set()) or set()) + remaining_active = getattr(backend, "active_model_name", None) + if remaining_loading or remaining_active: + raise RuntimeError( + "A safetensors chat model is still active or loading " + "after unload; retry before loading a diffusion image model." + ) + + +def _release_other_gpu_owners_for_diffusion() -> None: + """Best-effort: shut down export subprocess + active training before + a diffusion load. Both can hold multi-GB of VRAM and would OOM the + diffusion allocation on consumer GPUs.""" + # Export resident checkpoint. We tear down a SETTLED export + # (current_checkpoint populated AND is_export_active() False) + # because that means the export ran to completion and the user + # can re-load the result. An in-flight export job + # (is_export_active() True) is NEVER touched here: terminating + # it would corrupt the user's partial output artifact. + # + # The route layer also rejects /images/load with HTTP 409 via + # _raise_if_export_active when is_export_active() is True. 
This + # helper repeats the local check anyway so that direct backend + # callers (tests, scripts, future routes that forget the + # higher-level guard) cannot still kill an active export. + # Training-active check runs FIRST so direct backend callers + # (tests, scripts, future routes) cannot bypass the route layer's + # 409 by calling ``load_model`` directly while a training run is + # active (round 15 P1 #3). The route layer's + # ``_raise_if_training_active`` still runs ahead of the load to + # surface the conflict as 409; this helper re-raises so direct + # callers see the same RuntimeError the export-active path raises. + try: + from core.training import get_training_backend # type: ignore + except Exception as exc: + logger.debug("training module not importable: %s", exc) + else: + try: + training_active = bool(get_training_backend().is_training_active()) + except Exception as exc: + # Unverifiable status -> fail closed (might be active). + raise RuntimeError( + "Could not verify training status before loading a " + "diffusion image model." + ) from exc + if training_active: + raise RuntimeError( + "Training is currently active. Stop the training run " + "before loading a diffusion image model." + ) + + try: + from core.export import get_export_backend # type: ignore + except Exception as exc: + logger.debug("export module not importable: %s", exc) + return + + # Round 18 P1 #6: ``get_export_backend()`` raising used to be a + # silent ``return`` so direct ``DiffusionBackend.load_model`` + # callers could proceed toward GPU allocation without being able + # to verify export ownership. Fail closed instead, matching the + # route-level helper which already maps "Could not verify" / + # "Could not access" failures to HTTP 503. + try: + exp = get_export_backend() + except Exception as exc: + raise RuntimeError( + "Could not verify export status before loading a " "diffusion image model." 
+ ) from exc + + is_export_active_fn = getattr(exp, "is_export_active", None) + if is_export_active_fn is not None: + try: + export_is_active = bool(is_export_active_fn()) + except Exception as exc: + # Round 16 P2 #8: distinguish unverifiable status from + # active export. The previous "treat as active" mapping + # surfaced as a misleading 409 conflict; raise a + # "Could not verify" RuntimeError so the route layer + # maps it to 503 (retryable) instead. + raise RuntimeError( + "Could not verify export status before loading a " + "diffusion image model." + ) from exc + if export_is_active: + # Round 14 P2 #10: the prior behaviour logged a warning + # and continued, so direct ``DiffusionBackend.load_model`` + # callers (tests, scripts) silently bypassed the route + # layer's 409. Hard-refuse instead so any code path that + # reaches this helper while an export is active sees the + # same failure mode the route returns. + raise RuntimeError( + "An export job is currently active. Stop the export " + "job before loading a diffusion image model." + ) + + if getattr(exp, "current_checkpoint", None): + # Round 18 P1 #2: a wedged ``_shutdown_subprocess`` used to log + # at debug level and continue, so direct backend callers could + # allocate diffusion VRAM on top of an export checkpoint that + # still owned the GPU. Mirror the route-level helper and raise + # so the surrounding ``load_model`` bails out with a clean + # RuntimeError that the route layer maps to HTTP 503. + try: + logger.info("Shutting down idle export subprocess before diffusion load") + exp._shutdown_subprocess() + except Exception as exc: + raise RuntimeError( + "Could not unload the idle export checkpoint before " + "loading a diffusion image model." + ) from exc + exp.current_checkpoint = None + exp.is_vision = False + exp.is_peft = False + + # Note: active training is *not* stopped here. 
The route layer + # (`_raise_if_training_active` in routes/inference.py) refuses + # /images/load with HTTP 409 before this helper runs, so reaching + # this point with training still active would only happen in + # programmatic backend calls (tests, scripts). Silently terminating + # someone's training run when the diffusion load might still fail + # is worse than letting the load OOM and surfacing it explicitly. + + +def _release(obj: Any) -> None: + """Best-effort GPU-memory release for a pipeline being swapped out. + + Only drops the local reference (which the caller has already + nulled in its own scope) and runs ``gc.collect()`` so __del__ + fires. Does NOT call ``torch.cuda.empty_cache()`` here because + when the caller still holds the actual reference in a local / + attribute, ``empty_cache()`` would run before __del__ released + the weights and would not actually free GPU memory. Use + ``_drain_cuda_cache()`` AFTER the last reference has been nulled. + """ + if obj is None: + return + try: + del obj + except Exception: + pass + gc.collect() + + +def _drain_cuda_cache() -> None: + """Hand freed weights back to the active accelerator's allocator. + + Call this AFTER every reference to the freed object has been + dropped (caller's local + attribute) and a ``gc.collect()`` has + fired __del__. Calling earlier would empty an already-pinned + cache and not actually release the memory. + + Handles CUDA *and* MPS (Apple Silicon) so a diffusion swap on + macOS actually returns VRAM to the Metal allocator. 
+ """ + try: + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + except Exception: + pass + try: + import torch + + mps_backend = getattr(getattr(torch, "backends", None), "mps", None) + if mps_backend is not None and mps_backend.is_available(): + mps_module = getattr(torch, "mps", None) + empty_cache = ( + getattr(mps_module, "empty_cache", None) if mps_module else None + ) + if empty_cache is not None: + empty_cache() + except Exception: + pass + + +# ─── Module-level singleton ─────────────────────────────────────────── + + +_singleton: Optional[DiffusionBackend] = None +_singleton_lock = threading.Lock() + + +def get_diffusion_backend() -> DiffusionBackend: + """Return the process-wide diffusion backend (lazy-instantiated).""" + global _singleton + if _singleton is None: + with _singleton_lock: + if _singleton is None: + _singleton = DiffusionBackend() + return _singleton + + +async def async_generate( + backend: DiffusionBackend, + **kwargs: Any, +) -> "Any": + """Run ``generate_image`` in the default executor so route handlers + do not block the event loop for the 5-30 s a diffusion step takes.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: backend.generate_image(**kwargs)) + + +async def async_generate_with_metadata( + backend: DiffusionBackend, + **kwargs: Any, +) -> tuple[Any, dict[str, Any]]: + """Run ``generate_image_with_metadata`` in the default executor. 
+ + Used by the /images/generate route so the response model / family + fields reflect the pipeline that actually produced the image, even + if an unload races the route between the forward returning and the + response being assembled (round 13 P2 #9).""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + lambda: backend.generate_image_with_metadata(**kwargs), + ) diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index bf8a3c04df..f3f282893a 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -612,6 +612,18 @@ def __init__(self): self._process: Optional[subprocess.Popen] = None self._port: Optional[int] = None self._model_identifier: Optional[str] = None + # Pending-load identifier: set BEFORE _download_gguf starts and + # cleared after the load finishes (success or failure). Delete + # guards and cross-workload handoff helpers read it via + # ``loading_model_identifier`` so a multi-GB HF download cannot + # have its cache rmtree'd or be ignored by /images/load, + # /training/start, /export/load while it is still resolving. + # ``_loading_hf_variant`` mirrors the same lifetime so the + # per-variant delete guard at routes/models.py:/delete-finetuned + # compares against the NEW variant rather than the previous + # loaded ``hf_variant`` (round 15 P1 #2). + self._loading_model_identifier: Optional[str] = None + self._loading_hf_variant: Optional[str] = None self._gguf_path: Optional[str] = None self._hf_repo: Optional[str] = None self._hf_variant: Optional[str] = None @@ -713,6 +725,33 @@ def base_url(self) -> str: def model_identifier(self) -> Optional[str]: return self._model_identifier + @property + def loading_model_identifier(self) -> Optional[str]: + """Identifier of a load currently in progress, or None. + + Populated while ``_download_gguf`` is fetching the GGUF for a + new ``load_model`` call. 
Cleared in the surrounding + ``finally`` block, so a failed load leaves it None. Delete + guards in ``routes/models.py`` and handoff helpers in + ``routes/inference.py`` consult this so a long HF download + cannot have its destination rmtree'd or be ignored by a + concurrent /images/load that thinks llama-server is idle.""" + return self._loading_model_identifier + + @property + def loading_hf_variant(self) -> Optional[str]: + """``hf_variant`` of the load currently in progress, or None. + + Mirrors ``loading_model_identifier``'s lifetime so the + per-variant delete guards (routes/models.py /delete-cached and + /delete-finetuned) can compare against the NEW variant rather + than the previously-loaded one (round 15 P1 #2). Without this, + a directory with Q4 loaded and Q8 loading would still see the + stale Q4 ``hf_variant``, and a Q8 delete would be wrongly + allowed even though Q8 is being downloaded into the same + directory.""" + return self._loading_hf_variant + @property def is_vision(self) -> bool: return self._is_vision @@ -2599,7 +2638,68 @@ def load_model( # Serialise the whole load so concurrent /load calls never # leave two llama-server processes alive (#5401 / #5161). Does # not block /unload, /status, /load-progress. + # + # Publish ``_loading_model_identifier`` + ``_loading_hf_variant`` + # AFTER acquiring ``_serial_load_lock``. Round 15 P1 #1: the + # previous round 14 version set them outside the lock so a + # second queued ``load_model`` would overwrite or clear the + # identifier of the load currently holding the lock, breaking + # the delete-safety and GPU handoff guards. Cleared in + # ``finally`` so failure / cancellation leaves the pending + # state empty. Round 15 P1 #2 added ``_loading_hf_variant`` + # so per-variant delete guards can compare against the + # NEW variant rather than the previous loaded one. 
with self._serial_load_lock: + self._loading_model_identifier = model_identifier + self._loading_hf_variant = hf_variant + try: + return self._load_model_impl_locked( + gguf_path = gguf_path, + mmproj_path = mmproj_path, + hf_repo = hf_repo, + hf_variant = hf_variant, + hf_token = hf_token, + model_identifier = model_identifier, + is_vision = is_vision, + n_ctx = n_ctx, + chat_template_override = chat_template_override, + cache_type_kv = cache_type_kv, + speculative_type = speculative_type, + spec_draft_n_max = spec_draft_n_max, + n_threads = n_threads, + n_gpu_layers = n_gpu_layers, + n_parallel = n_parallel, + extra_args = extra_args, + ) + finally: + self._loading_model_identifier = None + self._loading_hf_variant = None + + def _load_model_impl_locked( + self, + *, + gguf_path: Optional[str] = None, + mmproj_path: Optional[str] = None, + hf_repo: Optional[str] = None, + hf_variant: Optional[str] = None, + hf_token: Optional[str] = None, + model_identifier: str, + is_vision: bool = False, + n_ctx: int = 4096, + chat_template_override: Optional[str] = None, + cache_type_kv: Optional[str] = None, + speculative_type: Optional[str] = None, + spec_draft_n_max: Optional[int] = None, + n_threads: Optional[int] = None, + n_gpu_layers: Optional[int] = None, + n_parallel: int = 1, + extra_args: Optional[List[str]] = None, + ) -> bool: + """Internal body of ``load_model``. The caller is responsible + for holding ``_serial_load_lock`` and for publishing / + clearing ``_loading_model_identifier`` + ``_loading_hf_variant`` + in the surrounding try/finally.""" + if True: # Duplicate /load that raced past the route-level check # (the first one hadn't published _healthy=True yet). If the # live server already satisfies this request, do nothing. 
diff --git a/studio/backend/main.py b/studio/backend/main.py index 004ae404cd..ca862b1ccb 100644 --- a/studio/backend/main.py +++ b/studio/backend/main.py @@ -293,6 +293,70 @@ def _precache(): lifespan = lifespan, ) + +# ── Validation error scrubber ──────────────────────────────────── +# Round 16 P2 #10: FastAPI's default RequestValidationError handler +# echoes the rejected ``input`` value back in the 422 body. A +# request like +# {"repo_id": "https://hf_token@huggingface.co/owner/repo"} +# is rejected by ``DiffusionLoadRequest._no_embedded_hf_tokens``, +# but the rejected URL would still appear in the response payload, +# leaking the token to the browser console / network log. Wrap the +# handler so any ``hf_xxxxx`` substring is replaced with +# ```` before serialisation. Scoped to the response body +# only; the underlying validator behaviour is unchanged. +from fastapi.exceptions import RequestValidationError as _RequestValidationError # noqa: E402 +from fastapi.encoders import jsonable_encoder as _jsonable_encoder # noqa: E402 +from fastapi.responses import JSONResponse as _JSONResponse # noqa: E402 +import re as _re_validation # noqa: E402 + + +_HF_TOKEN_VALIDATION_RE = _re_validation.compile(r"hf_[A-Za-z0-9]{20,}") + + +def _scrub_validation_obj(value): + """Recursively scrub ``hf_xxxxx`` tokens out of a value tree. + + Pydantic v2 nests raw ``ValueError`` (and other ``BaseException``) + instances under ``ctx.error``. Convert them to scrubbed strings + here; otherwise the default ``JSONResponse`` serializer raises + ``TypeError: Object of type ValueError is not JSON serializable`` + and the 422 turns into a 500 (round 17 P1 #1). Tuples become + lists so the downstream JSON encoder accepts them. 
+ """ + if isinstance(value, str): + return _HF_TOKEN_VALIDATION_RE.sub("", value) + if isinstance(value, BaseException): + return _scrub_validation_obj(str(value)) + if isinstance(value, tuple): + return [_scrub_validation_obj(v) for v in value] + if isinstance(value, list): + return [_scrub_validation_obj(v) for v in value] + if isinstance(value, dict): + # Round 21 P2 #7: pydantic surfaces ``input`` for ``string_type`` + # validation errors verbatim, including dict KEYS like + # ``{"hf_xxxxx": "owner/repo"}``. Scrub string keys too so the + # token does not leak through the 422 response body. + return { + ( + _scrub_validation_obj(k) if isinstance(k, str) else k + ): _scrub_validation_obj(v) + for k, v in value.items() + } + return value + + +@app.exception_handler(_RequestValidationError) +async def _validation_error_scrubbing_handler(request, exc): + # ``jsonable_encoder`` walks the scrubbed payload one more time + # to convert anything else Pydantic v2 surfaces (URL objects, + # Path objects, Url instances, etc.) into JSON-safe primitives. + return _JSONResponse( + status_code = 422, + content = _jsonable_encoder({"detail": _scrub_validation_obj(exc.errors())}), + ) + + # Initialize structured logging from loggers.config import LogConfig from loggers.handlers import LoggingMiddleware diff --git a/studio/backend/models/data_recipe.py b/studio/backend/models/data_recipe.py index b382ddb3d0..0bf3ce65bd 100644 --- a/studio/backend/models/data_recipe.py +++ b/studio/backend/models/data_recipe.py @@ -9,7 +9,13 @@ from typing import Any -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator + +# Round 23 P1 #5: identifier hardening reused from the chat models +# so /api/data_recipe/publish rejects control characters and +# URL-form ``hf_xxxxx`` tokens in ``repo_id`` before they reach +# log lines or the HF API. 
+from models.inference import _no_control_chars, _reject_embedded_hf_token class RecipePayload(BaseModel): @@ -60,6 +66,16 @@ class PublishDatasetRequest(BaseModel): description = "Execution artifact path captured by the UI for completed runs", ) + @field_validator("repo_id") + @classmethod + def _no_repo_id_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("repo_id") + @classmethod + def _no_repo_id_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class PublishDatasetResponse(BaseModel): success: bool = True @@ -74,6 +90,20 @@ class SeedInspectRequest(BaseModel): split: str | None = "train" preview_size: int = Field(default = 10, ge = 1, le = 50) + # Round 26 P1 #11: dataset_name reaches HF + log/echo paths, so + # mirror the hardening other dataset request models already do. + # Round 27 P1 #7: split and subset also flow into HF dataset + # APIs / errors and must be guarded the same way. + @field_validator("dataset_name", "subset", "split") + @classmethod + def _no_dataset_name_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("dataset_name", "subset", "split") + @classmethod + def _no_dataset_name_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class SeedInspectUploadRequest(BaseModel): # Legacy single-file flow (mutually exclusive with file_ids) @@ -89,6 +119,37 @@ class SeedInspectUploadRequest(BaseModel): unstructured_chunk_size: int | None = Field(default = None, ge = 1, le = 20000) unstructured_chunk_overlap: int | None = Field(default = None, ge = 0, le = 20000) + # Round 30 P1 #6: filename / file_names are reflected as dataset + # names + error/log messages; harden them the same way the sibling + # SeedInspectRequest hardens dataset_name. 
+ @field_validator("filename") + @classmethod + def _no_filename_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("filename") + @classmethod + def _no_filename_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + + @field_validator("file_names") + @classmethod + def _no_file_names_control_chars(cls, v): + if v is None: + return v + for i, entry in enumerate(v): + _no_control_chars(entry, f"file_names[{i}]") + return v + + @field_validator("file_names") + @classmethod + def _no_file_names_embedded_hf_tokens(cls, v): + if v is None: + return v + for i, entry in enumerate(v): + _reject_embedded_hf_token(entry, f"file_names[{i}]") + return v + @model_validator(mode = "after") def _check_mutual_exclusivity(self) -> "SeedInspectUploadRequest": has_legacy = self.content_base64 is not None diff --git a/studio/backend/models/datasets.py b/studio/backend/models/datasets.py index f20d6f2d15..28a4016514 100644 --- a/studio/backend/models/datasets.py +++ b/studio/backend/models/datasets.py @@ -7,7 +7,12 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator + +# Round 24 P1 #11: reuse the chat / diffusion / export identifier +# hardening so dataset routes also reject control characters and +# URL-embedded HF tokens in user-controlled identifiers. +from models.inference import _no_control_chars, _reject_embedded_hf_token class CheckFormatRequest(BaseModel): @@ -27,6 +32,18 @@ def _compat_split(cls, values: Any) -> Any: values.setdefault("train_split", values.pop("split")) return values + # Round 27 P1 #6: subset / train_split also flow into HF dataset + # APIs and errors/responses, so they need the same hardening. 
+ @field_validator("dataset_name", "subset", "train_split") + @classmethod + def _no_dataset_name_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("dataset_name", "subset", "train_split") + @classmethod + def _no_dataset_name_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class CheckFormatResponse(BaseModel): """Response for dataset format check""" @@ -57,6 +74,16 @@ class AiAssistMappingRequest(BaseModel): model_name: Optional[str] = None model_type: Optional[str] = None + @field_validator("dataset_name", "model_name") + @classmethod + def _no_identifier_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("dataset_name", "model_name") + @classmethod + def _no_identifier_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class AiAssistMappingResponse(BaseModel): """Response from LLM-assisted column classification and conversion advice.""" diff --git a/studio/backend/models/export.py b/studio/backend/models/export.py index 86ce2b05bf..6d899693b6 100644 --- a/studio/backend/models/export.py +++ b/studio/backend/models/export.py @@ -10,6 +10,13 @@ from pydantic import BaseModel, Field, field_validator from typing import List, Optional, Literal, Dict, Any +# Round 23 P1 #1 / #2 / #6: reuse the chat identifier validators +# so export requests reject newline / tab / control characters and +# URL-form ``hf_xxxxx`` tokens in any user-supplied identifier +# (Hub ``repo_id``, ``base_model_id``, the local +# ``checkpoint_path``) that flows into log lines or HF API calls. 
+from models.inference import _no_control_chars, _reject_embedded_hf_token + def _validate_save_directory(value: str) -> str: """Reject save_directory values that escape the export root.""" @@ -18,9 +25,17 @@ def _validate_save_directory(value: str) -> str: raw = str(value).strip() if not raw: raise ValueError("save_directory must not be empty") + # save_directory is logged verbatim by merged / base / GGUF export + # flows after resolution, so reject embedded HF tokens at the same + # boundary as the sibling identifier fields on export requests. + _reject_embedded_hf_token(raw, "save_directory") if "\x00" in raw: raise ValueError("save_directory may not contain null bytes") - if any(ch in raw for ch in ("\r", "\n")): + # Round 32 P1: reject ALL ASCII control characters (including + # TAB / VT / FF) so a caller cannot smuggle log-line breaks or + # subprocess argv splitters past the export worker. The earlier + # CR / LF check missed every other C0 byte. + if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in raw): raise ValueError("save_directory may not contain control characters") if len(raw) > 255: raise ValueError("save_directory must be <= 255 characters") @@ -54,6 +69,19 @@ class LoadCheckpointRequest(BaseModel): description = "Allow loading models with custom code. Only enable for checkpoints/base models you trust.", ) + # Round 23 P1 #6: ``checkpoint_path`` is logged verbatim by the + # export route. Apply the same control-char + embedded-token + # rejection the chat / diffusion / training request models use. 
+ @field_validator("checkpoint_path") + @classmethod + def _no_checkpoint_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("checkpoint_path") + @classmethod + def _no_checkpoint_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class ExportStatusResponse(BaseModel): """Current export backend status.""" @@ -117,6 +145,20 @@ def _check_save_directory(cls, v): description = "HuggingFace model ID of the base model (for model card metadata)", ) + # Round 23 P1 #1: ``repo_id`` (Hub destination) and + # ``base_model_id`` (model card metadata) both feed log lines + # and the HF API. Reject control characters and URL-form + # ``hf_xxxxx`` tokens before they reach those sinks. + @field_validator("repo_id", "base_model_id") + @classmethod + def _no_identifier_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("repo_id", "base_model_id") + @classmethod + def _no_identifier_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class ExportMergedModelRequest(ExportCommonOptions): """Request for exporting a merged PEFT model.""" @@ -163,6 +205,35 @@ def _check_save_directory(cls, v): description = "Hugging Face token for GGUF upload", ) + # Round 23 P1 #2: GGUF export endpoint defines its own + # ``repo_id`` (does not inherit from ExportCommonOptions), so + # the chat-style hardening needs to be applied here separately. + # ``quantization_method`` is forwarded to the export worker + # command line, so it gets the control-char check too even + # though it does not normally carry tokens. 
+ @field_validator("repo_id") + @classmethod + def _no_repo_id_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("repo_id") + @classmethod + def _no_repo_id_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + + @field_validator("quantization_method") + @classmethod + def _no_quantization_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + # Round 30 P1 #5: quantization_method is forwarded into worker + # command lines and reflected in error / success text, so also + # reject embedded HF tokens to mirror the repo_id hardening. + @field_validator("quantization_method") + @classmethod + def _no_quantization_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class ExportLoRAAdapterRequest(ExportCommonOptions): """Request for exporting only the LoRA adapter (not merged).""" diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index b5626951c4..749e5d8dc4 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -60,6 +60,29 @@ def normalize_blank_chat_template_override( return None return value + # Round 20 P1 #5: extend the diffusion-side identifier hardening + # (round 5 P2 / round 15 P1 #5) to the chat LoadRequest. Newline + # / tab / control characters in ``model_path`` or ``gguf_variant`` + # would otherwise be echoed verbatim into structured-log lines + # ("Loading model %s") and let a caller smuggle in fake log + # entries, and an embedded ``hf_...`` token in a URL-form path + # would leak the credential into the same log sinks the + # diffusion route already redacts. + @field_validator("model_path", "gguf_variant") + @classmethod + def _no_identifier_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + # Round 21 P1 #1: also reject embedded HF tokens in + # ``gguf_variant``. 
A caller can pass a variant string like + # ``Q4_K_M-hf_xxxxxxxx`` that flows into log sinks via the + # GGUF resolver path; without this only ``model_path`` was + # protected. + @field_validator("model_path", "gguf_variant") + @classmethod + def _no_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + cache_type_kv: Optional[str] = Field( None, description = "KV cache data type for both K and V (e.g. 'f16', 'bf16', 'q8_0', 'q4_1', 'q5_1')", @@ -104,12 +127,47 @@ def normalize_blank_chat_template_override( ), ) + # Round 28 P1 #13: each entry is forwarded verbatim to a logged + # subprocess command line and reflected in errors. Reject control + # chars and embedded HF tokens for every list entry; allow None. + @field_validator("llama_extra_args") + @classmethod + def _no_extra_args_control_chars(cls, v): + if v is None: + return v + for i, entry in enumerate(v): + _no_control_chars(entry, f"llama_extra_args[{i}]") + return v + + @field_validator("llama_extra_args") + @classmethod + def _no_extra_args_embedded_hf_tokens(cls, v): + if v is None: + return v + for i, entry in enumerate(v): + _reject_embedded_hf_token(entry, f"llama_extra_args[{i}]") + return v + class UnloadRequest(BaseModel): """Request to unload a model""" model_path: str = Field(..., description = "Model identifier to unload") + # Round 20 P1 #5: mirror the LoadRequest identifier hardening so + # /api/inference/unload also rejects control characters and + # URL-embedded HF tokens before the path reaches structured log + # sinks. 
+ @field_validator("model_path") + @classmethod + def _no_identifier_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("model_path") + @classmethod + def _no_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class ValidateModelRequest(BaseModel): """ @@ -130,6 +188,22 @@ class ValidateModelRequest(BaseModel): None, description = "GGUF quantization variant (e.g. 'Q4_K_M')" ) + # Round 20 P1 #5: same identifier hardening as LoadRequest / + # UnloadRequest. /api/inference/validate flows directly into + # ``ModelConfig.from_identifier`` and the resulting log lines, so + # control characters and embedded HF tokens must not survive. + @field_validator("model_path", "gguf_variant") + @classmethod + def _no_identifier_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + # Round 21 P1 #2: extend embedded-token rejection to + # ``gguf_variant`` here too (mirrors LoadRequest). + @field_validator("model_path", "gguf_variant") + @classmethod + def _no_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class ValidateModelResponse(BaseModel): """ @@ -1421,3 +1495,193 @@ class AnthropicMessagesResponse(BaseModel): stop_reason: Optional[str] = None stop_sequence: Optional[str] = None usage: AnthropicUsage = Field(default_factory = AnthropicUsage) + + +# ── Diffusion image generation ──────────────────────────────────── + + +def _no_control_chars(value: Optional[str], field_name: str) -> Optional[str]: + """Reject newlines, tabs, and other ASCII control chars in + identifiers that get logged before HF validates them. + + Authenticated callers could otherwise inject ``\\n`` / ``\\r`` / + ``\\t`` / NUL into ``logger.info("Loading diffusion model %s", + repo_id)`` and forge fake log lines. 
HF repo ids and filenames + legitimately contain only ``[A-Za-z0-9._/-]``, so this is also a + useful correctness check (catches accidental ``"my repo\\n"`` + paste). Tab is included in the reject set because some logging + sinks split fields on tab; allowing it would still let an + attacker forge fake columns. + """ + if value is None: + return value + for ch in value: + if ch == "\x7f" or ord(ch) < 0x20: + raise ValueError( + f"{field_name} contains control characters; use a plain " + "Hugging Face repo / file name." + ) + return value + + +import re as _re + +_EMBEDDED_HF_TOKEN_RE = _re.compile(r"hf_[A-Za-z0-9]{20,}") + + +def _reject_embedded_hf_token(value: Optional[str], field_name: str) -> Optional[str]: + """Refuse identifiers that contain an embedded ``hf_xxx`` token. + + Round 15 P1 #5: ``repo_id`` and ``base_repo`` accept URL-style + strings (``https://hf_token@huggingface.co/owner/repo``). The + token would otherwise be stored in ``self._repo_id`` and echoed + back through ``status()`` to every authenticated browser session. + Log redaction (``_redact_hf_tokens``) covers the logger sink, but + the public status payload also needed to refuse the input. Use + the dedicated ``hf_token`` field for authentication. + """ + if value is not None and _EMBEDDED_HF_TOKEN_RE.search(value): + raise ValueError( + f"{field_name} must not embed a Hugging Face token; " + "pass it via the dedicated hf_token field instead." + ) + return value + + +class DiffusionLoadRequest(BaseModel): + """Load a diffusion image-generation model. + + repo_id is the HF repo (either GGUF-only or full diffusers layout). + gguf_filename selects the quant when repo_id is a GGUF repo. + base_repo overrides the auto-picked diffusers base used for the + VAE / text encoders when loading a GGUF-only repo. + """ + + # repo_id and base_repo are HF Hub identifiers in this release. 
+ # Local-path support is gated behind a frontend / Tauri + # ``load-diffusion-model`` directory lease producer that has not + # shipped yet (round 32 P1 #3 in the PR reviewer trail). The + # 1024-char cap matches POSIX PATH_MAX so future local-path + # support can flip on without re-validating the field width. + repo_id: str = Field( + ..., + min_length = 1, + max_length = 1024, + description = ( + "HF repo id (owner/name). Local filesystem paths are reserved " + "for a future native-lease flow and currently rejected by the " + "route's _looks_like_local_diffusion_path guard." + ), + ) + # Round 30 P1 #4: chat /api/inference/load gates native local paths + # through a signed native_path_lease grant before the backend + # touches the filesystem. Mirror that here so /api/inference/images/ + # load cannot be used as an authenticated probe for arbitrary + # local directories. Optional; Hub ids (no leading slash / tilde) + # skip the lease check entirely. The Images UI does not yet + # surface a local-path picker, so callers that omit this field + # always get the Hub-id code path. + native_path_lease: Optional[str] = Field( + None, + description = "Frontend-visible signed native path grant for a local repo_id", + ) + gguf_filename: Optional[str] = Field( + None, + max_length = 512, + description = "GGUF filename inside repo_id (Q4_K_S, Q8_0, ...)", + ) + base_repo: Optional[str] = Field( + None, + max_length = 1024, + description = ( + "Diffusers base repo (HF id) for VAE + text encoders. Local " + "paths are gated on the same future native-lease flow as " + "repo_id." 
+ ), + ) + base_repo_native_path_lease: Optional[str] = Field( + None, + description = "Frontend-visible signed native path grant for a local base_repo", + ) + family: Optional[str] = Field( + None, + max_length = 64, + description = "Force pipeline family: flux.2-klein | flux.2 | flux.1 | qwen-image | stable-diffusion-3 | stable-diffusion-xl", + ) + hf_token: Optional[str] = Field( + None, description = "HuggingFace token for gated models" + ) + enable_model_cpu_offload: bool = Field( + True, + description = "Offload submodules to CPU between forwards. Trades a small speed hit for ~6 GB less VRAM on FLUX-class models.", + ) + + @field_validator("repo_id", "gguf_filename", "base_repo", "family") + @classmethod + def _no_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("repo_id", "gguf_filename", "base_repo") + @classmethod + def _no_embedded_hf_tokens(cls, v, info): + # Round 17 P2 #12: ``gguf_filename`` is forwarded to the + # backend and stored on ``DiffusionBackend._gguf_filename``, + # which is later surfaced via ``status()`` / log lines. If a + # user pastes a URL-form quant path like + # ``https://hf_xxxxx@huggingface.co/.../flux.gguf`` we drop + # the embedded credential before it can leak. + return _reject_embedded_hf_token(v, info.field_name) + + +# torch.Generator.manual_seed packs into signed int64; values outside +# [-2**63, 2**63 - 1] raise ``Overflow when unpacking long long`` deep +# in the C++ layer. uint64 is also routinely cited online so accept +# any value the underlying RNG could store and bounce the rest at the +# Pydantic layer with a clean error. 
# Inclusive lower / upper bounds applied to ``DiffusionGenerateRequest.seed``
# through the ``ge`` / ``le`` constraints below: the signed 64-bit minimum
# up to the unsigned 64-bit maximum.
_SEED_MIN = -(2**63)
_SEED_MAX = (2**64) - 1


class DiffusionGenerateRequest(BaseModel):
    """Generate a single image from the currently-loaded diffusion model.

    Every numeric field is range-checked at the schema layer, so an
    out-of-range request is rejected with a validation error before it
    reaches the generation backend.
    """

    # Required, 1-4000 characters.
    prompt: str = Field(..., min_length = 1, max_length = 4000)
    # Optional; same 4000-character cap as ``prompt``.
    negative_prompt: Optional[str] = Field(None, max_length = 4000)
    # Defaults to 24; accepted range is 1-200 inclusive.
    num_inference_steps: int = Field(24, ge = 1, le = 200)
    # Defaults to 3.5; accepted range is 0.0-20.0 inclusive.
    guidance_scale: float = Field(3.5, ge = 0.0, le = 20.0)
    # Output size in pixels: default 1024, range 64-2048, and additionally
    # required to be a multiple of 8 by ``_multiple_of_eight``.
    width: int = Field(1024, ge = 64, le = 2048)
    height: int = Field(1024, ge = 64, le = 2048)
    # Optional; when given it must lie in [_SEED_MIN, _SEED_MAX].
    seed: Optional[int] = Field(
        None,
        ge = _SEED_MIN,
        le = _SEED_MAX,
        description = "Deterministic seed for reproducible outputs",
    )

    @field_validator("width", "height")
    @classmethod
    def _multiple_of_eight(cls, v: int) -> int:
        """Return ``v`` unchanged, raising ``ValueError`` unless it is a multiple of 8."""
        # ``v % 8`` is truthy exactly when there is a non-zero remainder.
        if v % 8:
            raise ValueError("width and height must be multiples of 8")
        return v
+ seed: Optional[int] = None + seed_str: Optional[str] = None + duration_ms: int + model: Optional[str] = None + family: Optional[str] = None diff --git a/studio/backend/models/models.py b/studio/backend/models/models.py index 46ca4e3784..3c1257d1aa 100644 --- a/studio/backend/models/models.py +++ b/studio/backend/models/models.py @@ -5,9 +5,11 @@ Pydantic schemas for Model Management API """ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from typing import Optional, List, Dict, Any, Literal +from models.inference import _no_control_chars, _reject_embedded_hf_token + ModelType = Literal["text", "vision", "audio", "embeddings"] @@ -206,6 +208,19 @@ class AddScanFolderRequest(BaseModel): ..., description = "Absolute or relative directory path to scan for models" ) + # path is reflected back in /scan-folders error details and logged + # via add_scan_folder_endpoint when the directory is missing, so + # apply the same identifier hardening used on other logged paths. + @field_validator("path") + @classmethod + def _no_path_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator("path") + @classmethod + def _no_path_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + class ScanFolderInfo(BaseModel): """A registered custom model scan folder.""" diff --git a/studio/backend/models/training.py b/studio/backend/models/training.py index 7c53b0fee5..4963761a9d 100644 --- a/studio/backend/models/training.py +++ b/studio/backend/models/training.py @@ -8,6 +8,13 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from typing import Any, Optional, List, Dict, Literal +# Round 22 P1 #1: reuse the chat / diffusion identifier validators +# so /api/training/start rejects newline / tab / control characters +# and URL-form ``hf_xxxxx`` tokens in ``model_name``. 
Without these +# a caller could log-line-smuggle through "Loading model %s" lines +# and leak the bearer token into structured-log sinks. +from models.inference import _no_control_chars, _reject_embedded_hf_token + _MAX_BATCH_SIZE = 4096 _MAX_GRAD_ACCUM = 4096 @@ -49,6 +56,52 @@ class TrainingStartRequest(BaseModel): model_name: str = Field( ..., description = "Model identifier (e.g., 'unsloth/llama-3-8b-bnb-4bit')" ) + + # Identifier hardening: extended progressively across analogous + # request models. format_type is copied into training_kwargs and + # written into trainer log lines, so it shares the same boundary. + @field_validator( + "model_name", + "hf_dataset", + "subset", + "train_split", + "eval_split", + "format_type", + ) + @classmethod + def _no_model_name_control_chars(cls, v, info): + return _no_control_chars(v, info.field_name) + + @field_validator( + "model_name", + "hf_dataset", + "subset", + "train_split", + "eval_split", + "format_type", + ) + @classmethod + def _no_model_name_embedded_hf_tokens(cls, v, info): + return _reject_embedded_hf_token(v, info.field_name) + + # local_datasets / local_eval_datasets are user-controlled lists + # reflected back in /api/training/start error details when + # _validate_local_dataset_paths fails, so the same control-char + + # embedded-token guards apply per entry. 
+ @field_validator("local_datasets", "local_eval_datasets") + @classmethod + def _no_local_dataset_control_chars(cls, v, info): + for i, entry in enumerate(v or []): + _no_control_chars(entry, f"{info.field_name}[{i}]") + return v + + @field_validator("local_datasets", "local_eval_datasets") + @classmethod + def _no_local_dataset_embedded_hf_tokens(cls, v, info): + for i, entry in enumerate(v or []): + _reject_embedded_hf_token(entry, f"{info.field_name}[{i}]") + return v + training_type: Literal["LoRA/QLoRA", "Full Finetuning", "Continued Pretraining"] = ( Field( ..., diff --git a/studio/backend/requirements/no-torch-runtime.txt b/studio/backend/requirements/no-torch-runtime.txt index 85294114b1..b43d11cf36 100644 --- a/studio/backend/requirements/no-torch-runtime.txt +++ b/studio/backend/requirements/no-torch-runtime.txt @@ -43,15 +43,44 @@ safetensors>=0.4.3 datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0 accelerate>=0.34.1 peft>=0.18.0,!=0.11.0 +# Round 33 P1: reverted the round-26 hub>=1.3.0 floor to the +# pre-PR >=0.34.0 floor. Studio's install_python_stack later +# forces hub==0.36.2 via studio.txt (constrained by +# transformers==4.57.6 in extras-no-deps.txt), so the 1.3.0 +# floor was internally inconsistent. extras-no-deps holds +# transformers at 4.x, so the transformers-5.x is_offline_mode +# concern that motivated the original bump never actually +# triggers on the supported install path. +# Round 34 P1: the line itself must stay because install.sh +# --no-torch installs THIS file with --no-deps and does not run +# studio.txt afterward; without the package line a no-torch +# install ends with no huggingface_hub at all and the new +# diffusion / chat GGUF paths fail with ModuleNotFoundError. +# Verified live on B200: hub 0.36.2 + transformers 4.57.6 + +# diffusers 0.37.1 imports Flux2KleinPipeline cleanly and runs +# end-to-end image generation. 
huggingface_hub>=0.34.0 hf_transfer -diffusers +# Floor 0.37.0 introduces Flux2KleinPipeline + Flux2Pipeline which the +# Studio Images page imports for the default curated picker. +diffusers>=0.37.0 +# Required by diffusers.GGUFQuantizationConfig (used by the Images page +# to load FLUX.2 / FLUX.1 / Qwen-Image GGUFs from the Hub). Floor at +# 0.10.0 to match the diffusers requirement; older gguf releases raise +# at single-file load time. +gguf>=0.10.0 # Transitive deps required because this file is installed with --no-deps. # Without these, `from transformers import AutoConfig` fails at import time. regex typing_extensions filelock +# `requests` and its urllib3/charset chain are required by huggingface_hub's +# blob downloader; diffusers + GGUFQuantizationConfig 500 on first +# /api/inference/images/load otherwise. +requests +urllib3 +charset_normalizer httpx httpcore certifi diff --git a/studio/backend/requirements/studio.txt b/studio/backend/requirements/studio.txt index 96f8816b57..b360261e44 100644 --- a/studio/backend/requirements/studio.txt +++ b/studio/backend/requirements/studio.txt @@ -1,6 +1,11 @@ # Studio UI backend dependencies typer fastapi +# Required by FastAPI's multipart upload route validation +# (routes/datasets.py uploads files via UploadFile/File). Without +# this, importing the routes package raises RuntimeError on startup +# and CPU-only test environments fail before any test runs. +python-multipart uvicorn pydantic packaging @@ -18,3 +23,10 @@ diceware ddgs cryptography>=42.0.0 httpx>=0.27.0 +# Studio Images page runtime. Flux2KleinPipeline / Flux2Pipeline / +# QwenImagePipeline / StableDiffusion3Pipeline are available in +# diffusers>=0.37.0, and GGUFQuantizationConfig requires the gguf +# package (round 20 P1 #4: fresh standard Studio installs failed on +# /images/load because these were only listed in the extras files). 
def _validate_logged_identifier(value: str, field_name: str) -> str:
    """Vet an identifier before it can reach logger or cache paths.

    Runs the control-character check first and the embedded-HF-token
    check second. A token-shaped string such as
    ``owner/hf_abcdefghij0123456789`` would otherwise pass the cheap
    ``_is_valid_repo_id`` regex and end up in warning logs.

    Args:
        value: The caller-supplied identifier (for example ``repo_id``).
        field_name: Field name echoed in the error detail.

    Returns:
        The identifier as returned by the two checks.

    Raises:
        HTTPException: 422 carrying the validator's message when either
            check rejects the value.
    """
    try:
        cleaned = _reject_embedded_hf_token(
            _no_control_chars(value, field_name),
            field_name,
        )
    except ValueError as rejection:
        raise HTTPException(status_code = 422, detail = str(rejection)) from rejection
    return cleaned
def _raise_if_training_active_for_export() -> None:
    """Block an export operation while a training run is in flight.

    Outcomes:
      * ``core.training`` cannot be imported -> logged at debug level and
        treated as "no tracker"; the call returns without raising.
      * Looking up the training status raises -> ``HTTPException`` 503.
      * Training is reported active -> ``HTTPException`` 409.
    """
    try:
        from core.training import get_training_backend  # type: ignore
    except Exception as import_error:
        logger.debug(
            "core.training not importable, skipping training guard: %s",
            import_error,
        )
        return

    try:
        training_running = get_training_backend().is_training_active()
    except Exception as status_error:
        logger.warning(
            "Could not verify training status before export op: %s",
            status_error,
        )
        raise HTTPException(
            status_code = 503,
            detail = "Could not verify training status before the export operation. Try again.",
        ) from status_error

    # Guard clause: nothing to refuse when no run is active.
    if not training_running:
        return
    raise HTTPException(
        status_code = 409,
        detail = "Training is currently active. Stop the training run before starting an export operation.",
    )
@contextlib.asynccontextmanager
async def _export_public_window():
    """Hold the public-load window open across an /export/* operation.

    backend.export_*() runs in a worker thread and does not flip
    ``_export_active = True`` until the worker actually starts. During
    that gap another workload calling ``_release_export_for`` would see
    ``is_export_active() == False`` and tear down the export subprocess.
    This context manager mirrors the load_checkpoint guard so the
    pending entry stays set for the whole export call.

    Preflight checks run in this order, before the body executes:
      1. ``_raise_if_training_active_for_export`` - 409 / 503.
      2. ``_raise_if_export_active_for_export`` - 409 / 503.
      3. ``_raise_if_helper_advisor_busy("export")`` - refuses while the
         helper is busy.

    The window is only treated as published once all three pass, so a
    refused request never triggers a clear it did not earn.

    Raises:
        HTTPException: Propagated unchanged from any preflight check.
    """
    from routes.inference import (
        _clear_public_load_window,
        _raise_if_helper_advisor_busy,
    )

    export_window_published = False
    try:
        _raise_if_training_active_for_export()
        _raise_if_export_active_for_export()
        _raise_if_helper_advisor_busy("export")
        export_window_published = True
        yield
    finally:
        if export_window_published:
            try:
                _clear_public_load_window("export")
            except Exception as exc:
                # Still best-effort: a failed clear must not mask the
                # export's own result or exception. Log it so a stuck
                # pending counter can be diagnosed instead of vanishing.
                logger.warning(
                    "Could not clear the export public-load window: %s", exc
                )
try: - from core.inference import get_inference_backend + from core.training import get_training_backend # type: ignore + except Exception as e: + logger.debug( + "core.training not importable, skipping export training guard: %s", + e, + ) + else: + try: + trn = get_training_backend() + active = trn.is_training_active() + except Exception as e: + logger.warning( + "Could not verify training status before export load: %s", e + ) + raise HTTPException( + status_code = 503, + detail = ( + "Could not verify training status before loading " + "an export checkpoint. Try again." + ), + ) from e + if active: + raise HTTPException( + status_code = 409, + detail = ( + "Training is currently active. Stop the training " + "run before loading an export checkpoint." + ), + ) - inf = get_inference_backend() - if inf.active_model_name: - logger.info( - "Unloading inference model '%s' to free GPU memory for export", - inf.active_model_name, + backend = get_export_backend() + # Refuse to reload the export checkpoint while an export job + # is still running. ``ExportBackend.load_checkpoint`` would + # terminate the running subprocess in order to spawn a new + # one, silently corrupting the partial output the user is + # waiting on (round 13 P1 #1). Runs BEFORE the chat / + # diffusion unloads below: a 409 from this guard must not + # leave the user's chat or diffusion GPU owners freed for + # nothing (round 14 P1 #1). ``is_export_active`` may be + # absent on older / mocked backends; treat missing as "no + # async-job tracker available" and skip rather than + # fail-closed. + is_export_active_fn = getattr(backend, "is_export_active", None) + if is_export_active_fn is not None: + try: + export_is_active = bool(is_export_active_fn()) + except Exception as e: + logger.warning( + "Could not verify export status before export load: %s", e + ) + raise HTTPException( + status_code = 503, + detail = ( + "Could not verify export status before loading " + "an export checkpoint. Try again." 
+ ), + ) from e + if export_is_active: + raise HTTPException( + status_code = 409, + detail = ( + "An export job is currently active. Stop the " + "export job before loading another checkpoint." + ), ) - inf._shutdown_subprocess() - inf.active_model_name = None - inf.models.clear() - except Exception as e: - logger.warning("Could not unload inference model: %s", e) - try: - from core.training import get_training_backend - - trn = get_training_backend() - if trn.is_training_active(): - logger.info("Stopping active training to free GPU memory for export") - trn.stop_training() - # Wait for training subprocess to actually exit before proceeding, - # otherwise it may still hold GPU memory when export tries to load. - for _ in range(60): # up to 30s - if not trn.is_training_active(): - break - import time - - time.sleep(0.5) - else: - logger.warning( - "Training subprocess did not exit within 30s, proceeding anyway" - ) - except Exception as e: - logger.warning("Could not stop training: %s", e) + # Free GPU memory: shut down any chat backend before loading + # the export checkpoint. Routes the unload through the shared + # helper so we cover llama-server is_active=True and + # safetensors loading_models -- the asymmetries round 9 + # reviews #1, #8, #9 flagged. + from routes.inference import ( + _clear_public_load_window, + _raise_if_helper_advisor_busy, + _release_chat_for, + _release_diffusion_for, + ) + + # Round 28 P1 #6: refuse before any release fires so AI Assist + # busy does not first tear down idle diffusion. + # Round 30 P1 #8: also publishes a public-load pending entry so + # a concurrent helper / advisor start cannot win the start + # lock between our snapshot and load_checkpoint flipping + # current_checkpoint / is_export_active. + _raise_if_helper_advisor_busy("export") + export_load_window_published = True + # Round 24 P1 #3: release diffusion BEFORE chat so a failing + # diffusion unload does not leave the user with no chat + # model loaded. 
Same reasoning as the training-start flow + # (round 18 P1 #8 / round 24 P1 #2). Earlier rounds kept the + # chat release first because the helper was best-effort; + # now that ``_release_diffusion_for`` is strict it must run + # while chat is still resident so a failure preserves it. + await _release_diffusion_for("export load") + await _release_chat_for("export") - backend = get_export_backend() # load_checkpoint spawns and waits on a subprocess and can take # minutes. Run it in a worker thread so the event loop stays # free to serve the live log SSE stream concurrently. @@ -127,6 +303,18 @@ async def load_checkpoint( status_code = 500, detail = f"Failed to load checkpoint: {str(e)}", ) + finally: + # Round 30 P1 #8: clear the public-load pending entry once the + # load attempt completes (success or failure). Skipped when + # the helper-busy check itself raised so the counter stays in + # sync with publishes. + if export_load_window_published: + try: + from routes.inference import _clear_public_load_window + except Exception: + pass + else: + _clear_public_load_window("export") @router.post("/cleanup", response_model = ExportOperationResponse) @@ -140,7 +328,12 @@ async def cleanup_export_memory( """ try: backend = get_export_backend() - success = await asyncio.to_thread(backend.cleanup_memory) + # Run the cleanup under the same public-load window /export/* + # uses so a queued export's handoff gap cannot race a cleanup + # call that tears down current_checkpoint. The window also + # refuses 409 if training or another export is in flight. 
+ async with _export_public_window(): + success = await asyncio.to_thread(backend.cleanup_memory) if not success: raise HTTPException( @@ -211,15 +404,16 @@ async def export_merged_model( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_merged_model, - save_directory = request.save_directory, - format_type = request.format_type, - push_to_hub = request.push_to_hub, - repo_id = request.repo_id, - hf_token = request.hf_token, - private = request.private, - ) + async with _export_public_window(): + success, message, output_path = await asyncio.to_thread( + backend.export_merged_model, + save_directory = request.save_directory, + format_type = request.format_type, + push_to_hub = request.push_to_hub, + repo_id = request.repo_id, + hf_token = request.hf_token, + private = request.private, + ) if not success: raise HTTPException(status_code = 400, detail = message) @@ -251,15 +445,16 @@ async def export_base_model( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_base_model, - save_directory = request.save_directory, - push_to_hub = request.push_to_hub, - repo_id = request.repo_id, - hf_token = request.hf_token, - private = request.private, - base_model_id = request.base_model_id, - ) + async with _export_public_window(): + success, message, output_path = await asyncio.to_thread( + backend.export_base_model, + save_directory = request.save_directory, + push_to_hub = request.push_to_hub, + repo_id = request.repo_id, + hf_token = request.hf_token, + private = request.private, + base_model_id = request.base_model_id, + ) if not success: raise HTTPException(status_code = 400, detail = message) @@ -291,14 +486,15 @@ async def export_gguf( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_gguf, - save_directory = request.save_directory, - quantization_method = 
request.quantization_method, - push_to_hub = request.push_to_hub, - repo_id = request.repo_id, - hf_token = request.hf_token, - ) + async with _export_public_window(): + success, message, output_path = await asyncio.to_thread( + backend.export_gguf, + save_directory = request.save_directory, + quantization_method = request.quantization_method, + push_to_hub = request.push_to_hub, + repo_id = request.repo_id, + hf_token = request.hf_token, + ) if not success: raise HTTPException(status_code = 400, detail = message) @@ -330,14 +526,15 @@ async def export_lora_adapter( """ try: backend = get_export_backend() - success, message, output_path = await asyncio.to_thread( - backend.export_lora_adapter, - save_directory = request.save_directory, - push_to_hub = request.push_to_hub, - repo_id = request.repo_id, - hf_token = request.hf_token, - private = request.private, - ) + async with _export_public_window(): + success, message, output_path = await asyncio.to_thread( + backend.export_lora_adapter, + save_directory = request.save_directory, + push_to_hub = request.push_to_hub, + repo_id = request.repo_id, + hf_token = request.hf_token, + private = request.private, + ) if not success: raise HTTPException(status_code = 400, detail = message) diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index bf92055929..effde958e7 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -213,6 +213,9 @@ def _friendly_error(exc: Exception) -> str: ListOpenAIContainersResponse, OpenAIContainerRequest, OpenAIContainerSummary, + DiffusionLoadRequest, + DiffusionGenerateRequest, + DiffusionGenerateResponse, ) from core.inference.anthropic_compat import ( anthropic_messages_to_openai, @@ -239,6 +242,578 @@ def _friendly_error(exc: Exception) -> str: studio_router = APIRouter() +def _raise_if_training_active(workload: str) -> None: + """Refuse a chat/diffusion/export load while training is active. 
+ + Without this guard the load path would either (a) silently stop a + running training run via _release_other_gpu_owners_for_diffusion + or (b) double-spend VRAM and OOM both jobs. Both are worse for the + user than a 409 explaining why the request was refused. + + Failure modes are split: + * ``core.training`` cannot be imported (CI, isolated tests, + custom builds) -> silently return; nothing to protect. + * ``core.training`` is importable but ``get_training_backend()`` + or ``is_training_active()`` raises -> 503 fail-closed. We + cannot verify the GPU is free, so taking the safer route + avoids OOMing an unverifiable training run. + """ + try: + from core.training import get_training_backend # type: ignore + except Exception: + return + try: + trn = get_training_backend() + active = trn.is_training_active() + except Exception as exc: + logger.warning( + "Could not verify training status before %s load: %s", + workload, + exc, + ) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not verify training status before loading the " + f"{workload} model. Try again." + ), + ) from exc + if active: + raise HTTPException( + status_code = 409, + detail = ( + f"Training is currently active. Stop the training run " + f"before loading a {workload} model." + ), + ) + + +def _raise_if_export_active(workload: str) -> None: + """Refuse a chat/diffusion/training load while an export job is + actively running. + + Symmetric with ``_raise_if_training_active``: export is also a + long-running GPU-owning job a user does not want silently killed. + ONLY raises when ``is_export_active() is True`` (an export + subprocess is currently producing output). A settled + ``current_checkpoint`` is NOT an active job -- it is just held + GPU memory and gets dropped by ``_release_export_for``. + + Failure-mode split: + * ``core.export`` cannot be imported -> silently skip. + * ``get_export_backend()`` raises -> 503 fail closed. 
+ * Backend does not expose ``is_export_active`` -> silently + skip. Older ExportBackend builds and several test mocks + only expose ``current_checkpoint``; there is no async-job + tracker for them, and forcing a 503 here would break those + flows without adding any safety they did not previously have. + * ``is_export_active()`` itself raises -> 503 fail closed + (round 10 review #7). + """ + try: + from core.export import get_export_backend # type: ignore + except Exception: + return + try: + exp = get_export_backend() + except Exception as exc: + logger.warning( + "Could not verify export backend before %s load: %s", + workload, + exc, + ) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not verify export status before loading the " + f"{workload} model. Try again." + ), + ) from exc + is_export_active_fn = getattr(exp, "is_export_active", None) + if is_export_active_fn is None: + return + try: + active = bool(is_export_active_fn()) + except Exception as exc: + logger.warning( + "Could not verify export status before %s load: %s", + workload, + exc, + ) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not verify export status before loading the " + f"{workload} model. Try again." + ), + ) from exc + if active: + raise HTTPException( + status_code = 409, + detail = ( + f"An export job is currently active. Stop the export " + f"job before loading a {workload} model." + ), + ) + + +def _raise_if_helper_advisor_busy(workload: str) -> None: + """Round 28 P1 #1 / #4 / #5 / #6: refuse a new GPU workload while + AI Assist helper / advisor still owns its PRIVATE LlamaCppBackend. + + Called early so callers do NOT first tear down idle export / + diffusion / chat owners just to fail on the helper check. + + Round 30 P1 #7-#10: also publishes a public-load pending entry + under the helper-advisor start lock so a concurrent helper start + sees the pending public owner and refuses VRAM. 
Callers MUST + invoke ``_clear_public_load_window(workload)`` in a paired + finally to clear the entry once the load attempt completes. + """ + try: + from utils.datasets.llm_assist import ( + _HELPER_ADVISOR_START_LOCK, + _publish_public_load_pending, + helper_advisor_busy, + public_load_pending, + ) + except Exception: + return + with _HELPER_ADVISOR_START_LOCK: + try: + busy = helper_advisor_busy() + except Exception as exc: + logger.warning( + "Could not verify helper/advisor status before %s load: %s", + workload, + exc, + ) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not verify AI Assist status before starting {workload}. " + f"Try again." + ), + ) from exc + if busy: + raise HTTPException( + status_code = 503, + detail = ( + f"AI Assist (helper / advisor GGUF) is still using the GPU. " + f"Wait for it to finish before starting {workload}." + ), + ) + # Round 35 P1: also refuse when another public workload is + # already mid-handoff (passed its own helper-busy snapshot + # but has not yet flipped is_training_active / + # current_checkpoint / loading_model_identifier / + # diffusion is_loading). Without this two public loads can + # both pass their idle snapshots concurrently and race + # destructive owner teardown. + if public_load_pending(): + raise HTTPException( + status_code = 503, + detail = ( + f"Another GPU workload is mid-handoff. Wait for it to " + f"finish before starting {workload}." + ), + ) + _publish_public_load_pending(workload) + + +def _clear_public_load_window(workload: str) -> None: + """Pair for ``_raise_if_helper_advisor_busy``: release the pending + public-load publish so a subsequent helper start can proceed. 
+ Safe to call when the module import failed (no-op).""" + try: + from utils.datasets.llm_assist import _release_public_load_pending + except Exception: + return + try: + _release_public_load_pending(workload) + except Exception: + pass + + +async def _release_llama_for(workload: str) -> None: + """Unload the llama-server (GGUF) chat backend if it owns the + GPU. Treats ``is_loaded`` OR ``is_active`` OR + ``loading_model_identifier`` as held: the third covers an HF GGUF + download that has not yet flipped ``is_active`` to True (round + 13 P1 #7). Without it, /images/load, /training/start, and + /export/load could start while a long ``_download_gguf`` was in + flight; llama-server would then come up afterwards and double-own + the GPU. + + Round 16 P1 #4: a missing or unavailable llama backend is a + silent no-op (fresh install / no GGUF use), but an unload that + actually FAILS raises 503 so the caller does not start a new GPU + workload while llama-server is still resident. + """ + try: + llama = get_llama_cpp_backend() + except Exception as exc: + logger.debug("llama-server unavailable for %s: %s", workload, exc) + return + + is_loaded = bool(getattr(llama, "is_loaded", False)) + is_active = bool(getattr(llama, "is_active", False)) + is_loading = bool(getattr(llama, "loading_model_identifier", None)) + if not (is_loaded or is_active or is_loading): + return + + logger.info( + "Unloading GGUF chat (loaded=%s active=%s loading=%s) before %s load", + is_loaded, + is_active, + is_loading, + workload, + ) + try: + ok = await asyncio.to_thread(llama.unload_model) + except Exception as exc: + logger.warning("Failed to unload GGUF chat before %s load: %s", workload, exc) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not unload the existing GGUF chat model before " + f"starting {workload}." 
+ ), + ) from exc + + # Round 28 P1 #11: a pending HF GGUF download cancelled by + # unload_model() takes up to a few seconds to settle (the load + # thread observes _cancel_event in its finally and clears + # loading_model_identifier). Wait briefly so a legitimate cancel + # does not 503. Mirrors the /api/inference/unload settling wait. + deadline = time.monotonic() + 5.0 + while ( + bool(getattr(llama, "loading_model_identifier", None)) + and time.monotonic() < deadline + ): + await asyncio.sleep(0.1) + + # Round 18 P1 #1: previously only the raised-exception path was + # treated as failure. ``llama.unload_model()`` returning ``False`` + # (subprocess refused to terminate, IPC timeout) or leaving + # ``is_loaded`` / ``is_active`` / ``loading_model_identifier`` + # populated after the call meant the next workload could allocate + # while llama-server was still resident. Re-read the same three + # fields and fail closed if anything is still set so the caller + # retries instead of double-owning VRAM. + if ( + ok is False + or bool(getattr(llama, "is_loaded", False)) + or bool(getattr(llama, "is_active", False)) + or bool(getattr(llama, "loading_model_identifier", None)) + ): + raise HTTPException( + status_code = 503, + detail = ( + "The existing GGUF chat model is still active or loading " + f"after unload; retry before starting {workload}." + ), + ) + + +async def _release_safetensors_chat_for(workload: str) -> None: + """Unload the safetensors / Unsloth chat backend (drains both + ``active_model_name`` and ``loading_models``) if it owns the GPU. + + Round 16 P1 #4: ``unload_model`` returning ``False`` (subprocess + wedged, IPC timeout) used to be silently ignored, leaving the + old chat model resident while a new GPU workload started on top. + Treat ``False`` as failure and raise 503 so the caller retries + instead of double-owning VRAM. 
+ """ + try: + from core.inference import get_inference_backend as _gib # type: ignore + + inf = _gib() + except Exception as exc: + logger.debug("safetensors unavailable for %s: %s", workload, exc) + return + + async def _unload_required(model_name: str) -> None: + try: + ok = await asyncio.to_thread(inf.unload_model, model_name) + except Exception as exc: + raise HTTPException( + status_code = 503, + detail = ( + f"Could not unload safetensors chat model " + f"'{model_name}' before starting {workload}." + ), + ) from exc + if ok is False: + raise HTTPException( + status_code = 503, + detail = ( + f"Safetensors backend refused to unload " + f"'{model_name}' before starting {workload}. " + "Try again." + ), + ) + # Round 19 P1 #1: ``unload_model`` returning ``True`` does not + # by itself guarantee the orchestrator dropped the model. The + # worker may have responded ``unloaded`` while still holding + # weights, or a concurrent ``load`` from another tab may have + # repopulated ``loading_models`` between calls. Re-read the + # tracker fields and fail closed if this specific name is + # still active or loading so the caller retries. + remaining_loading = set(getattr(inf, "loading_models", set()) or set()) + active_after = getattr(inf, "active_model_name", None) + if active_after == model_name or model_name in remaining_loading: + raise HTTPException( + status_code = 503, + detail = ( + f"Safetensors chat model '{model_name}' is still active " + f"or loading after unload; retry before starting {workload}." 
+ ), + ) + + active_model_name = getattr(inf, "active_model_name", None) + loading_models = set(getattr(inf, "loading_models", set()) or set()) + if active_model_name: + logger.info( + "Unloading safetensors chat '%s' before %s load", + active_model_name, + workload, + ) + await _unload_required(active_model_name) + for loading in loading_models: + if loading == active_model_name: + continue + logger.info( + "Unloading in-flight safetensors chat '%s' before %s load", + loading, + workload, + ) + await _unload_required(loading) + + # Round 21 P1 #4: final sweep without the owned_names filter. + # A concurrent ``/load`` that appeared AFTER the initial + # snapshot was previously ignored here, so a chat model that + # started loading during the unload window let the surrounding + # training / export / GGUF / diffusion start anyway. Treat ANY + # surviving active / loading entry as a failure so the caller + # retries rather than racing the new chat load for VRAM. + remaining_loading = set(getattr(inf, "loading_models", set()) or set()) + remaining_active = getattr(inf, "active_model_name", None) + if remaining_loading or remaining_active: + raise HTTPException( + status_code = 503, + detail = ( + "A safetensors chat model is still active or loading " + f"after unload; retry before starting {workload}." + ), + ) + + +async def _release_chat_for(workload: str) -> None: + """Shared 'release any GPU-owning chat backend' helper. + + Used by training / export / diffusion handoffs (which need BOTH + chat backends gone). The GGUF chat-load path uses only + ``_release_safetensors_chat_for`` because it is itself starting + llama-server -- we cannot release the backend we are about to + start. Conversely, the standard chat-load path releases only + the llama side. + """ + # Round 27 P1 #2: helper / advisor GGUF loads run on a PRIVATE + # LlamaCppBackend (round 26 P1 #1) so the global llama checks + # below do not see them. 
Refuse the handoff while a helper / + # advisor still owns its private backend so a new GPU workload + # does not allocate on top of helper VRAM and OOM. + try: + from utils.datasets.llm_assist import helper_advisor_busy + except Exception: + pass + else: + if helper_advisor_busy(): + raise HTTPException( + status_code = 503, + detail = ( + f"AI Assist (helper / advisor GGUF) is still using the " + f"GPU. Wait for it to finish before starting {workload}." + ), + ) + await _release_llama_for(workload) + await _release_safetensors_chat_for(workload) + + +async def _release_export_for(workload: str) -> None: + """Shared 'drop a settled export checkpoint' helper. + + ONLY shuts down the export subprocess when ``current_checkpoint`` + is set AND ``is_export_active()`` is False -- i.e. a previously + completed load is just holding GPU memory. An in-flight export + job (``is_export_active()`` True) is NEVER touched here; the + route layer is expected to refuse the workload with HTTP 409 + via ``_raise_if_export_active`` before calling this. + + Round 17 P1 #8: idle-export shutdown failures now raise HTTP 503 + instead of being swallowed, so a wedged export subprocess does + not silently leave GPU memory pinned while training / chat / + diffusion start on top. + """ + try: + from core.export import get_export_backend # type: ignore + except Exception as exc: + logger.debug("export backend unavailable for %s: %s", workload, exc) + return + + try: + exp = get_export_backend() + except Exception as exc: + logger.warning("Could not access export backend before %s: %s", workload, exc) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not access export backend before starting {workload}. " + "Try again." 
+ ), + ) from exc + + has_checkpoint = bool(getattr(exp, "current_checkpoint", None)) + is_export_active_fn = getattr(exp, "is_export_active", None) + if is_export_active_fn is None: + active = False + else: + try: + active = bool(is_export_active_fn()) + except Exception as exc: + raise HTTPException( + status_code = 503, + detail = ( + f"Could not verify export status before starting " + f"{workload}. Try again." + ), + ) from exc + + # If an /export/* operation has published its pending window, the + # backend may not have flipped is_export_active() = True yet but the + # subprocess is mid-handoff. Refuse to tear it down so the in-flight + # export sees a stable subprocess. + try: + from utils.datasets.llm_assist import public_load_pending_for + + export_pending = public_load_pending_for("export") + except Exception: + export_pending = False + if has_checkpoint and not active and export_pending: + raise HTTPException( + status_code = 503, + detail = ( + f"Another export operation is mid-handoff. Wait for it " + f"to finish before starting {workload}." + ), + ) + + if has_checkpoint and not active: + try: + logger.info( + "Shutting down idle export (checkpoint=%s) for %s", + has_checkpoint, + workload, + ) + await asyncio.to_thread(exp._shutdown_subprocess) + except Exception as exc: + logger.warning("Could not shut down export for %s: %s", workload, exc) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not unload the idle export checkpoint before " + f"starting {workload}. Try again." + ), + ) from exc + exp.current_checkpoint = None + exp.is_vision = False + exp.is_peft = False + + +async def _release_diffusion_for(workload: str) -> None: + """Strict diffusion-unload helper for cross-workload handoffs. + + Round 17 P1 #4-7: the GGUF chat load, safetensors chat load, + training start, and export load paths each had their own + best-effort try/except around ``diff_backend.unload_model()``. 
+ A wedged diffusion pipeline therefore stayed resident while a + new GPU workload started on top. This helper raises HTTP 503 + when the unload fails or leaves diffusion resident, so the + caller fails closed. + """ + try: + from core.inference.diffusion import get_diffusion_backend # type: ignore + except Exception as exc: + logger.debug("diffusion backend unavailable for %s: %s", workload, exc) + return + + diff_backend = get_diffusion_backend() + try: + diff_status = diff_backend.status() + except Exception as exc: + logger.warning("Could not verify diffusion status before %s: %s", workload, exc) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not verify diffusion status before starting " + f"{workload}. Try again." + ), + ) from exc + + if not (diff_status.get("is_loaded") or diff_status.get("is_loading")): + return + + logger.info( + "Unloading diffusion (loaded=%s loading=%s) before %s", + diff_status.get("is_loaded"), + diff_status.get("is_loading"), + workload, + ) + try: + result = await asyncio.to_thread(diff_backend.unload_model) + except Exception as exc: + logger.warning("Failed to unload diffusion before %s: %s", workload, exc) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not unload the existing diffusion image model " + f"before starting {workload}. Try again." + ), + ) from exc + + # Round 18 P1 #5: a successful pre-check status() and a + # success-shaped unload result used to mask a post-unload + # status() failure (after = {}) and let the caller proceed + # without proof that diffusion released VRAM. Fail closed + # instead so training / chat / export retry rather than + # double-owning the GPU. + try: + after = diff_backend.status() + except Exception as exc: + logger.warning( + "Could not verify diffusion status after unload before %s: %s", + workload, + exc, + ) + raise HTTPException( + status_code = 503, + detail = ( + f"Could not verify diffusion unload before starting " + f"{workload}. Try again." 
+ ), + ) from exc + if result is False or after.get("is_loaded") or after.get("is_loading"): + raise HTTPException( + status_code = 503, + detail = ( + f"The diffusion image model is still active after unload; " + f"retry before starting {workload}." + ), + ) + + def _detect_safetensors_features(backend, chat_template: Optional[str]) -> dict: """Classify reasoning/tool capabilities via the GGUF classifier so flags match across backends. gpt-oss is overridden because Harmony @@ -583,6 +1158,10 @@ async def load_model( """ native_grant_backed = False model_log_label = request.model_path + # Round 30 P1 #7 / #9: track which branch (GGUF / safetensors) + # published a public-load pending entry so the outer finally + # decrements the same counter, even on early exception. + chat_load_window_workload: Optional[str] = None try: # Validate user-supplied llama-server pass-through args up front # so a managed-flag collision returns 400 before any model work. @@ -734,15 +1313,45 @@ async def load_model( detail = "gpu_ids is not supported for GGUF models yet.", ) - llama_backend = get_llama_cpp_backend() - unsloth_backend = get_inference_backend() + # Symmetric lifecycle guard: refuse a chat load while + # training is active. Diffusion and export paths refuse; + # without this the GGUF chat load would start llama-server + # while training still owned VRAM and double-spend it. + # Also refuse when an export job is in flight: same + # reasoning as diffusion (terminating a live export would + # corrupt the user's exported artifact). + _raise_if_training_active("chat") + _raise_if_export_active("chat") + # Round 28 P1 #1: refuse before the release helpers fire + # so we do not tear down an idle export / diffusion just to + # then 503 on the helper check. + _raise_if_helper_advisor_busy("GGUF chat") + chat_load_window_workload = "GGUF chat" + # Round 24 P1 #4: release order is now + # export -> diffusion -> safetensors chat (was + # export -> safetensors chat -> diffusion). 
A wedged + # diffusion unload used to fire AFTER the safetensors + # chat was already gone, so the user lost both. Drop + # the chat last so an earlier failure preserves it. + await _release_export_for("GGUF chat") + await _release_diffusion_for("GGUF chat load") - # Unload any active Unsloth model first to free VRAM - if unsloth_backend.active_model_name: - logger.info( - f"Unloading Unsloth model '{unsloth_backend.active_model_name}' before loading GGUF" - ) - unsloth_backend.unload_model(unsloth_backend.active_model_name) + llama_backend = get_llama_cpp_backend() + # Round 19 P2 #8: previously also called + # ``unsloth_backend = get_inference_backend()`` here, but + # the binding was never used in the GGUF branch. Eager + # construction makes the GGUF-only path needlessly fail + # or pay startup cost when the safetensors backend is + # unavailable / lazy-initialised; the shared + # ``_release_safetensors_chat_for`` below already + # handles missing-backend cases as a no-op. + + # Unload any safetensors / Unsloth model. Uses the shared + # helper so we also drain ``loading_models`` (round 10 + # review #4); the inline version only checked + # ``active_model_name`` and let an in-flight safetensors + # load race the new GGUF allocation. + await _release_safetensors_chat_for("GGUF chat") # Inherit llama_extra_args from the previous load when the # request omits the field (the chat-settings Apply path @@ -912,29 +1521,38 @@ async def load_model( ) # ── Standard path: load via Unsloth/transformers ────────── - backend = get_inference_backend() + # Symmetric lifecycle guard: refuse a chat load while training + # or an export is active so we do not OOM both jobs together + # and so we do not silently corrupt an in-flight export. + _raise_if_training_active("chat") + _raise_if_export_active("chat") + # Round 28 P1 #1: refuse before the release helpers tear down + # idle GPU owners. 
+ _raise_if_helper_advisor_busy("safetensors chat") + chat_load_window_workload = "safetensors chat" + # Round 24 P1 #5: release order is now + # export -> diffusion -> llama-chat (was + # export -> llama-chat -> diffusion). A wedged diffusion + # unload used to fire AFTER the GGUF chat was already gone, + # so the user lost both. Drop llama-chat last so an earlier + # failure preserves it. + await _release_export_for("safetensors chat") + await _release_diffusion_for("safetensors chat load") - # Unload any active GGUF model first - llama_backend = get_llama_cpp_backend() - if llama_backend.is_loaded: - logger.info("Unloading GGUF model before loading Unsloth model") - llama_backend.unload_model() + backend = get_inference_backend() - # Shut down any export subprocess to free VRAM - try: - from core.export import get_export_backend + # Unload any active or mid-download llama-server. Shared + # helper so this stays in sync with the GGUF path's + # symmetric ``_release_safetensors_chat_for``. + await _release_llama_for("safetensors chat") - exp_backend = get_export_backend() - if exp_backend.current_checkpoint: - logger.info( - "Shutting down export subprocess to free GPU memory for inference" - ) - exp_backend._shutdown_subprocess() - exp_backend.current_checkpoint = None - exp_backend.is_vision = False - exp_backend.is_peft = False - except Exception as e: - logger.warning("Could not shut down export subprocess: %s", e) + # Export was already dropped above via the shared + # ``await _release_export_for("safetensors chat")`` call + # (which checks is_export_active() before the destructive + # _shutdown_subprocess). The previous inline block here + # repeated the unconditional shutdown and would terminate + # an in-flight export job; round 11 review #2 flagged the + # asymmetry. The inline block is intentionally removed. 
# Auto-detect quantization for LoRA adapters from adapter_config.json # The training pipeline patches this file with "unsloth_training_method" @@ -1097,6 +1715,14 @@ async def load_model( if any(h.lower() in msg.lower() for h in not_supported_hints): msg = f"This model is not supported yet. Try a different model. (Original error: {msg})" raise HTTPException(status_code = 500, detail = f"Failed to load model: {msg}") + finally: + # Round 30 P1 #7 / #9: clear whichever chat branch published a + # public-load pending entry so a subsequent helper / advisor + # start can proceed. Set on the GGUF / safetensors branches + # after _raise_if_helper_advisor_busy succeeds; stays None for + # the already-loaded fast paths above. + if chat_load_window_workload is not None: + _clear_public_load_window(chat_load_window_workload) @router.post("/validate", response_model = ValidateModelResponse) @@ -1188,23 +1814,101 @@ async def unload_model( try: # Check if the GGUF backend has this model loaded or is loading it llama_backend = get_llama_cpp_backend() - if llama_backend.is_active and ( - llama_backend.model_identifier == request.model_path - or is_registered_native_path_label( - llama_backend.model_identifier, request.model_path - ) - or not llama_backend.is_loaded - ): - llama_backend.unload_model() + loaded_identifier = getattr(llama_backend, "model_identifier", None) + loading_identifier = getattr(llama_backend, "loading_model_identifier", None) + # Round 21 P1 #3: a GGUF download that has not yet flipped + # ``is_active`` to True (model_identifier still None, + # ``loading_model_identifier`` populated) used to fall + # through to the safetensors branch, which silently + # responded ``status="unloaded"`` while llama-server kept + # downloading. Match on either the loaded OR loading + # identifier so the explicit unload route can actually + # cancel a pending GGUF load. 
+ llama_matches_request = ( + loaded_identifier == request.model_path + or loading_identifier == request.model_path + or is_registered_native_path_label(loaded_identifier, request.model_path) + or is_registered_native_path_label(loading_identifier, request.model_path) + ) + # Round 26 P1 #5: the previous ``or not is_loaded`` fallback + # let an unload of ``owner/B`` cancel a pending llama download + # of ``owner/A`` and silently leave safetensors ``owner/B`` + # alive. Only enter the llama branch when the request actually + # matches the loaded/loading identifier, OR when llama-server + # is starting up without any identifier yet (the original + # narrow case we wanted to catch). + llama_is_starting_without_identifier = ( + getattr(llama_backend, "is_active", False) + and not getattr(llama_backend, "is_loaded", False) + and not loaded_identifier + and not loading_identifier + ) + should_unload_llama = ( + llama_matches_request + and (getattr(llama_backend, "is_active", False) or loading_identifier) + ) or llama_is_starting_without_identifier + if should_unload_llama: + # Round 19 P1 #6: previously this called + # ``llama_backend.unload_model()`` and unconditionally + # returned ``status="unloaded"`` even when the subprocess + # refused to terminate or IPC timed out. The frontend then + # showed the model as unloaded while llama-server was + # still resident. Treat ``False`` / leftover state as a + # 503 so the user retries. + ok = await asyncio.to_thread(llama_backend.unload_model) + # Round 26 P2 #15: explicit cancel of a pending GGUF load + # leaves loading_model_identifier set briefly until the + # load thread observes _cancel_event in its finally. Wait + # up to 5s so a legitimate cancel does not 503. 
+ deadline = time.monotonic() + 5.0 + while ( + getattr(llama_backend, "loading_model_identifier", None) + and time.monotonic() < deadline + ): + await asyncio.sleep(0.1) + if ( + ok is False + or getattr(llama_backend, "is_loaded", False) + or getattr(llama_backend, "is_active", False) + or getattr(llama_backend, "loading_model_identifier", None) + ): + raise HTTPException( + status_code = 503, + detail = ( + "The GGUF model is still active or loading after unload. " + "Try again." + ), + ) logger.info(f"Unloaded GGUF model: {request.model_path}") return UnloadResponse(status = "unloaded", model = request.model_path) # Otherwise, unload from Unsloth backend backend = get_inference_backend() - backend.unload_model(request.model_path) + # Round 19 P1 #6: same fail-closed treatment for safetensors. + # ``unload_model`` returning ``False`` or leaving + # ``active_model_name`` / ``loading_models`` populated for the + # requested name must surface to the client so the UI reflects + # the real state. + ok = await asyncio.to_thread(backend.unload_model, request.model_path) + active_after = getattr(backend, "active_model_name", None) + loading_after = set(getattr(backend, "loading_models", set()) or set()) + if ( + ok is False + or active_after == request.model_path + or request.model_path in loading_after + ): + raise HTTPException( + status_code = 503, + detail = ( + "The safetensors model is still active or loading after " + "unload. Try again." 
def _looks_like_local_diffusion_path(value: Optional[str]) -> bool:
    """Return True when ``value`` must be handled as a local filesystem path.

    Intended for ``repo_id`` / ``base_repo`` style identifiers. A False
    result means the value is shaped like a canonical Hub id
    (``owner/repo``) and nothing with that relative name exists on
    disk; every other input is reported as path-like so the caller can
    require a signed ``native_path_lease`` grant before using it.

    Checks, in order:

    1. Empty string / ``None`` -> False (nothing to classify).
    2. Leading ``/``, ``~``, ``./`` or ``../`` -> True.
    3. Any backslash -> True.
    4. ``Path(value).expanduser()`` fails or yields an absolute path
       -> True.
    5. Weight-file suffix (``.gguf``, ``.safetensors``, ``.bin``,
       ``.pt``, ``.pth``), compared case-insensitively -> True.
    6. Anything other than exactly two non-empty segments, neither of
       which is ``.`` or ``..`` -> True.
    7. The two-segment value exists on disk, resolved against the
       process working directory -> True.

    Step 7 deliberately touches the filesystem: a value such as
    ``exports/my-flux`` is syntactically indistinguishable from a Hub
    id, so existence is the only signal left. This exposes the
    existence of CWD-relative two-segment paths to the caller, a
    trade-off the original author accepted to stop such values from
    skipping the lease check.

    Args:
        value: Identifier supplied by the client, or ``None``.

    Returns:
        ``True`` if the value should go through the local-path (lease)
        flow, ``False`` if it can be treated as a Hub id.
    """
    if not value:
        return False
    # Obvious path prefixes and Windows-style separators.
    if value.startswith(("/", "~", "./", "../")):
        return True
    if "\\" in value:
        return True
    try:
        candidate = Path(value).expanduser()
    except (OSError, ValueError):
        # Fail closed: an identifier that cannot even be parsed as a
        # path must not silently fall through to the Hub loader.
        return True
    if candidate.is_absolute():
        return True
    # Weight-file shaped strings ("owner/model.gguf") are not Hub ids.
    # Compare case-insensitively so "owner/model.GGUF" is classified
    # the same way regardless of how the extension is spelled.
    if value.lower().endswith((".gguf", ".safetensors", ".bin", ".pt", ".pth")):
        return True
    # A canonical Hub id is exactly ``owner/repo``: two non-empty,
    # non-traversal segments.
    parts = value.split("/")
    if len(parts) != 2 or not parts[0] or not parts[1]:
        return True
    if parts[0] in (".", "..") or parts[1] in (".", ".."):
        return True
    # Last resort for Hub-shaped values: if the name resolves to an
    # existing file or directory relative to the CWD, treat it as
    # local. Errors while probing also fail closed.
    try:
        if candidate.exists():
            return True
    except (OSError, ValueError):
        return True
    return False
+ diffusion_load_window_published = False + try: + # Refuse before the long download starts: silently stopping a + # running training run to free VRAM was the previous behavior + # and left the user with no model loaded plus a dead training + # job. Same logic for export: an export subprocess that is + # mid-flight cannot be safely terminated without corrupting + # the output, so the request is refused with 409 instead of + # silently killing it. + _raise_if_training_active("diffusion") + _raise_if_export_active("diffusion") + # Round 28 P1 #4: AI Assist helper/advisor owns a private + # llama backend invisible to + # _release_chat_backend_for_diffusion's global checks. Refuse + # early so we do not first tear down an idle export + # checkpoint just to fail on the helper check inside + # load_model. + # Round 30 P1 #10: also publishes the public-load pending + # entry so a concurrent helper start cannot win the start + # lock between our snapshot and DiffusionBackend.load_model + # flipping is_loaded. Mark the publish flag immediately so + # any failure between here and the final return clears it. + _raise_if_helper_advisor_busy("diffusion") + diffusion_load_window_published = True + # Round 30 P1 #4: enforce the signed native_path_lease + # boundary the chat load path uses so local-path repo_id / + # base_repo cannot be probed without a frontend-issued grant. + # Hub ids pass through. + resolved_repo_id = ( + _resolve_diffusion_repo_for_request( + payload.repo_id, + payload.native_path_lease, + operation = "load-diffusion-model", + ) + or payload.repo_id + ) + resolved_base_repo = _resolve_diffusion_repo_for_request( + payload.base_repo, + payload.base_repo_native_path_lease, + operation = "load-diffusion-model", + ) + # Round 18 P1 #3 + P1 #7: the route used to drop chat and + # idle export BEFORE ``backend.load_model`` ran its cheap + # validation (family inference, GGUF filename checks, + # gated-token failures, missing diffusers). 
A malformed image + # request would therefore unload the user's chat model and + # then return a 400 with nothing loaded; if export cleanup + # raised, chat had already been dropped. + # ``DiffusionBackend.load_model`` itself calls + # ``_release_other_gpu_owners_for_diffusion`` (strict + # idle-export shutdown after round 18 P1 #2) and + # ``_release_chat_backend_for_diffusion`` (strict GGUF + + # safetensors unload after round 17 P1 #2 + round 18 P1 #4), + # so the GPU is still freed before any allocation, just + # AFTER validation. + backend = _get_diffusion_backend() + try: + status = await asyncio.get_running_loop().run_in_executor( + None, + lambda: backend.load_model( + repo_id = resolved_repo_id, + gguf_filename = payload.gguf_filename, + base_repo = resolved_base_repo, + family_override = payload.family, + hf_token = payload.hf_token, + enable_model_cpu_offload = payload.enable_model_cpu_offload, + # Round 38 P1: this route already published the + # "diffusion" pending marker above; tell the + # backend to ignore it so the parity check it + # now applies does not self-block on our own + # publication. + ignore_public_load_pending_workload = "diffusion", + ), + ) + return JSONResponse(content = status) + except RuntimeError as exc: + # Round 15 P2 #7 / round 16 P2 #7: backend-level conflict + # checks raise RuntimeError that surfaces here. + # Distinguish: + # - "Could not verify ..." -> 503 (retryable, status + # check itself failed), matching the route-level + # pre-check. + # - explicit "currently active" -> 409 conflict. + # - anything else -> 400 (bad request). 
+ detail = str(exc) + if ( + "Could not verify training status" in detail + or "Could not verify export status" in detail + or "Could not unload" in detail + or "refused to unload" in detail + or "still active after unload" in detail + # Round 19 P2 #7: round 18 introduced new + # RuntimeError phrasings (``still active or loading + # after unload``) that the original marker list did + # not cover, so a retryable chat-unload failure was + # returning HTTP 400 to the user instead of 503. + # Match both wordings. + or "still active or loading after unload" in detail + or "still loading after unload" in detail + # Round 28 P2 #15: AI Assist running (raised by + # _release_chat_backend_for_diffusion) is retryable. + or "AI Assist" in detail + # Backend mid-handoff race (raised by + # _raise_if_helper_advisor_busy_for_diffusion when + # another workload's public_load_pending is set) mirrors + # the route-level 503 at routes/inference.py:415, so the + # backend-surfaced phrasing must classify the same way. + or "Another GPU workload is mid-handoff" in detail + ): + # Round 17 P1 #2: chat unload failures raised by the + # backend helper map to 503 (retryable infra issue), + # matching the route-level _release_*_for helpers. + raise HTTPException(status_code = 503, detail = detail) from exc + if ( + "export job is currently active" in detail + or "Training is currently active" in detail + ): + raise HTTPException(status_code = 409, detail = detail) from exc + raise HTTPException(status_code = 400, detail = detail) from exc + except HTTPException: + raise + except Exception as exc: + logger.exception("Diffusion load failed") + raise HTTPException(status_code = 500, detail = str(exc)) + finally: + # Round 31 P1 #1 / #6: only clear when this request actually + # published. 
@studio_router.post("/images/generate", response_model = DiffusionGenerateResponse)
async def diffusion_generate(
    payload: DiffusionGenerateRequest,
    current_subject: str = Depends(get_current_subject),
):
    """Generate a single image from the loaded diffusion model.

    Returns a base64 PNG plus the generation parameters that produced
    it so the frontend can render the result and the user can reproduce
    it via the same seed.

    Raises:
        HTTPException: 400 when no diffusion model is loaded, or when
            generation raises ``ValueError`` / ``RuntimeError``; 500 for
            any other generation failure (logged with traceback).
    """
    backend = _get_diffusion_backend()
    if not backend.is_loaded:
        raise HTTPException(
            status_code = 400,
            detail = "No diffusion model is loaded. POST /api/inference/images/load first.",
        )

    # Monotonic clock: the elapsed time must not be skewed (or go
    # negative) if the wall clock is adjusted while generating.
    start = time.monotonic()
    try:
        # Imported lazily so the diffusion stack is only loaded when an
        # image request actually arrives.
        from core.inference.diffusion import (
            async_generate_with_metadata,
            encode_png_base64,
        )

        # ``meta`` carries the ``model`` / ``family`` values reported
        # alongside the image; they are read from it below instead of
        # from the backend after the call.
        image, meta = await async_generate_with_metadata(
            backend,
            prompt = payload.prompt,
            negative_prompt = payload.negative_prompt,
            num_inference_steps = payload.num_inference_steps,
            guidance_scale = payload.guidance_scale,
            width = payload.width,
            height = payload.height,
            seed = payload.seed,
        )
    except (ValueError, RuntimeError) as exc:
        # Both are surfaced as 400 with the original message; chain the
        # cause so the traceback keeps the originating error.
        raise HTTPException(status_code = 400, detail = str(exc)) from exc
    except Exception as exc:
        logger.exception("Diffusion generation failed")
        raise HTTPException(status_code = 500, detail = str(exc)) from exc

    duration_ms = int((time.monotonic() - start) * 1000)
    # Report the size of the image actually produced when it exposes
    # ``size``; otherwise fall back to the requested dimensions.
    actual_w, actual_h = (
        image.size if hasattr(image, "size") else (payload.width, payload.height)
    )
    return DiffusionGenerateResponse(
        image_b64 = encode_png_base64(image),
        image_mime = "image/png",
        width = int(actual_w),
        height = int(actual_h),
        num_inference_steps = payload.num_inference_steps,
        guidance_scale = payload.guidance_scale,
        seed = payload.seed,
        # str() of a Python int keeps full precision, so clients that
        # cannot represent large integers exactly can use this field
        # instead of the numeric ``seed`` above.
        seed_str = str(payload.seed) if payload.seed is not None else None,
        duration_ms = duration_ms,
        model = meta.get("model"),
        family = meta.get("family"),
    )
+ """ + try: + value = _no_control_chars(value, field_name) + value = _reject_embedded_hf_token(value, field_name) + except ValueError as exc: + raise HTTPException(status_code = 422, detail = str(exc)) from exc + return value + + def derive_model_type( is_vision: bool, audio_type: Optional[str], is_embedding: bool = False ) -> ModelType: @@ -1571,6 +1590,7 @@ async def get_model_config( This endpoint wraps the backend load_model_defaults function. """ + model_name = _validate_logged_identifier(model_name, "model_name") try: if not is_local_path(model_name): resolved = resolve_cached_repo_id_case(model_name) @@ -1580,7 +1600,11 @@ async def get_model_config( resolved, model_name, ) - model_name = resolved + # Round 23 P1 #7: re-validate the cache-resolved value + # (case-only resolver should be a no-op for these + # checks, but defend in depth in case the resolver + # ever broadens its match heuristic). + model_name = _validate_logged_identifier(resolved, "model_name") logger.info(f"Getting model config for: {model_name}") from utils.models.model_config import detect_audio_type @@ -1709,6 +1733,53 @@ def _is_path_under(path: Path, root: Path) -> bool: return False +def _diffusion_owned_targets(diff_status: dict) -> list[tuple[str, str | None]]: + """Return ``(owned_repo_or_path, owned_gguf_filename)`` pairs for + every diffusion target the backend currently holds. + + Pairs the active / pending repo with the active / pending GGUF + filename (not the UI-facing collapsed ``gguf_filename``) so the + per-variant delete guards know which quant is actually owned by + each repo. Without this pairing, a swap in progress (active + ``Q4_K_S``, pending ``Q8_0``) collapsed both to the pending + variant and the active ``Q4_K_S`` GGUF could be deleted while + still mmap'd by the resident pipeline (round 13 P1 #3-5). 
+ + Base repos are paired with ``None`` for the GGUF: the base / + component repo is loaded whole via ``from_pretrained`` and has no + per-variant delete to take advantage of. + """ + return [ + ( + diff_status.get("active_repo_id") or "", + diff_status.get("active_gguf_filename"), + ), + (diff_status.get("active_base_repo") or "", None), + ( + diff_status.get("pending_repo_id") or "", + diff_status.get("pending_gguf_filename"), + ), + (diff_status.get("pending_base_repo") or "", None), + ] + + +def _variant_delete_is_safe_for_owned_gguf( + requested_variant: str | None, + owned_gguf_filename: str | None, +) -> bool: + """True iff a per-variant delete for ``requested_variant`` against + a repo that owns ``owned_gguf_filename`` cannot remove the owned + file. + + Returns False (i.e. unsafe -> block the delete) when either + argument is missing so a NULL owned filename or a full-repo delete + (no variant) does not accidentally pass the guard.""" + if not requested_variant or not owned_gguf_filename: + return False + loaded_label = (_extract_quant_label(owned_gguf_filename.lower()) or "").lower() + return bool(loaded_label and loaded_label != requested_variant.lower()) + + def _is_path_under_lexically(path: Path, root: Path) -> bool: """Check containment without resolving the final path's symlink target.""" try: @@ -1724,7 +1795,15 @@ def _loaded_model_matches_deleted_path(active_model: str, deleted_path: Path) -> try: active = Path(active_model).expanduser().resolve() target = deleted_path.resolve() - return active == target or (target.is_dir() and active.is_relative_to(target)) + # Round 27 P1 #8: match bidirectionally so deleting a child + # directory of a loaded local model (e.g. .../my-flux/text_encoder + # while .../my-flux is loaded) also trips the guard. Mirrors + # the diffusion delete-guard pattern. 
+ return ( + active == target + or (target.is_dir() and active.is_relative_to(target)) + or (active.is_dir() and target.is_relative_to(active)) + ) except (OSError, RuntimeError, ValueError) as e: logger.debug( "Could not resolve loaded/deleted model paths; falling back to string comparison: %s", @@ -1732,8 +1811,10 @@ def _loaded_model_matches_deleted_path(active_model: str, deleted_path: Path) -> ) active_lower = active_model.lower() target_lower = str(deleted_path).lower() - return active_lower == target_lower or active_lower.startswith( - f"{target_lower}{os.sep}" + return ( + active_lower == target_lower + or active_lower.startswith(f"{target_lower}{os.sep}") + or target_lower.startswith(f"{active_lower}{os.sep}") ) @@ -1805,6 +1886,14 @@ async def delete_finetuned_model( Only paths under Studio's outputs/exports roots are accepted. Exported GGUF entries can delete one quantization variant at a time. """ + # Round 24 P1 #7 + P2 #13: harden both ``model_path`` and + # ``gguf_variant`` for control characters and embedded HF + # tokens, mirroring the chat / diffusion / training request + # validators. Both fields end up in logger.info(...) lines. + model_path = _validate_logged_identifier(model_path, "model_path") + if gguf_variant is not None: + gguf_variant = _validate_logged_identifier(gguf_variant, "gguf_variant") + if source not in {"training", "exported"}: raise HTTPException( status_code = 400, @@ -1893,6 +1982,32 @@ async def delete_finetuned_model( from routes.inference import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() + # Pending HF GGUF download targeting this path: round 14 P1 #3. + # ``loading_model_identifier`` is set before the download starts + # and cleared after the subprocess settles, so the user cannot + # rmtree the directory llama.cpp is writing into mid-flight. 
+ # Round 15 P1 #2: compare against ``loading_hf_variant`` (the + # variant being downloaded) rather than ``hf_variant`` (the + # PREVIOUS loaded variant, which is stale until the new load + # completes its late-metadata update). + loading_identifier = getattr(llama_backend, "loading_model_identifier", None) + loading_variant = getattr(llama_backend, "loading_hf_variant", None) + if ( + loading_identifier + and _loaded_model_matches_deleted_path( + loading_identifier, + target_path, + ) + and ( + not gguf_variant + or not loading_variant + or loading_variant.lower() == gguf_variant.lower() + ) + ): + raise HTTPException( + status_code = 409, + detail = "Cannot delete a model while it is loading", + ) if ( llama_backend.is_active and not llama_backend.is_loaded @@ -1968,6 +2083,77 @@ async def delete_finetuned_model( detail = "Could not verify model load status before deleting", ) from e + # Diffusion pipelines can also be loaded directly from a Studio + # outputs/exports path (e.g. user fine-tuned a FLUX LoRA, exported + # the merged repo locally, then loaded it via /images/load with a + # local path as repo_id). Without this guard /delete-finetuned + # could rmtree the directory the diffusion backend is reading from. + # is_loading is also blocked: status() exposes pending_repo_id / + # pending_base_repo during the load window so deletes during a + # mid-flight from_pretrained are refused. During a swap we still + # see the previous load's active_repo_id, so every owned path is + # checked rather than just the UI-facing one. + # Block both DIRECTIONS: + # * loaded path is the same as target (or a parent), and + # * loaded path is a child of target (so the user cannot rmtree + # a parent directory that contains the pipeline's mmap'd file). + # Fail-CLOSED on exception (503) like the llama.cpp / safetensors + # guards above: an unverifiable diffusion state means we cannot + # confirm the target is safe to rmtree. 
+ try: + from core.inference.diffusion import get_diffusion_backend + + diff_backend = get_diffusion_backend() + # include_internal=True so we can iterate active_*/pending_* + # raw paths against ``target_path`` (round 16 P1 #5). + diff_status = diff_backend.status(include_internal = True) + if diff_status.get("is_loaded") or diff_status.get("is_loading"): + target_str = str(target_path) + # Pair each owned repo / path with the GGUF variant it + # actually owns (round 13 P1 #5). For a swap in flight + # (active Q4_K_S, pending Q8_0) the active variant must + # NOT be deleted just because the pending variant uses + # a different quant. + for candidate, owned_gguf in _diffusion_owned_targets(diff_status): + if not candidate: + continue + try: + candidate_resolved = Path(candidate).expanduser().resolve() + except Exception: + continue + # Relative paths (the user can do + # `/images/load repo_id=exports/my-flux`) are still + # legitimate path candidates; resolve against the + # backend cwd so they can be compared with the + # absolute ``target_path``. Round 8 review #11. 
+ overlaps = ( + candidate_resolved == target_path + or str(candidate_resolved) == target_str + or _is_path_under(candidate_resolved, target_path) + or _is_path_under(target_path, candidate_resolved) + ) + if not overlaps: + continue + if export_type == "gguf" and _variant_delete_is_safe_for_owned_gguf( + gguf_variant, + owned_gguf, + ): + continue + raise HTTPException( + status_code = 400, + detail = "Unload the diffusion image model before deleting", + ) + except HTTPException: + raise + except Exception as e: + logger.warning( + "Could not check diffusion backend loaded model before delete: %s", e + ) + raise HTTPException( + status_code = 503, + detail = "Could not verify diffusion load status before deleting", + ) from e + try: if export_type == "gguf" and gguf_variant: if not target_path.is_dir(): @@ -2043,6 +2229,9 @@ async def get_lora_base_model( This endpoint wraps the backend get_base_model_from_lora function. """ + # Round 26 P1 #12: lora_path is echoed back in 404 detail and logs; + # harden it the same way other reflected identifiers are. + lora_path = _validate_logged_identifier(lora_path, "lora_path") try: base_model = get_base_model_from_lora(lora_path) @@ -2076,6 +2265,7 @@ async def check_vision_model( This endpoint wraps the backend is_vision_model function. """ + model_name = _validate_logged_identifier(model_name, "model_name") try: logger.info(f"Checking if vision model: {model_name}") is_vision = is_vision_model(model_name) @@ -2104,6 +2294,7 @@ async def check_embedding_model( This endpoint wraps the backend is_embedding_model function. """ + model_name = _validate_logged_identifier(model_name, "model_name") try: logger.info(f"Checking if embedding model: {model_name}") is_embedding = is_embedding_model(model_name, hf_token = hf_token) @@ -2141,6 +2332,7 @@ async def get_gguf_variants( with file sizes, whether the model supports vision, and the recommended default variant. 
""" + repo_id = _validate_logged_identifier(repo_id, "repo_id") try: from utils.models.model_config import is_local_path, list_local_gguf_variants @@ -2248,6 +2440,13 @@ async def get_gguf_download_progress( Tracks completed shard downloads in snapshots and in-progress downloads in the blobs directory (incomplete files). """ + # Round 28 P1 #14: mirror the hardening on the generic + # /download-progress route. Both repo_id and variant are echoed + # into the cache-scan path and can reach logs on the failure + # branch via the surrounding try/except. + repo_id = _validate_logged_identifier(repo_id, "repo_id") + if variant: + variant = _validate_logged_identifier(variant, "variant") try: if not _is_valid_repo_id(repo_id): return { @@ -2335,6 +2534,10 @@ async def get_download_progress( "progress": 0, "cache_path": None, } + # Round 24 P1 #9: ``repo_id`` flows into log lines deep in + # ``_get_repo_size_cached`` on lookup failure, so the same + # hardening the request-body models use applies here too. + repo_id = _validate_logged_identifier(repo_id, "repo_id") try: if not _is_valid_repo_id(repo_id): return _empty @@ -2598,39 +2801,283 @@ async def delete_cached_model( are removed (e.g. ``UD-Q4_K_XL``). Otherwise the entire repo is deleted. Refuses if the model is currently loaded for inference. """ + # Round 24 P1 #8 + #10: harden both ``repo_id`` and ``variant`` + # against control characters / embedded HF tokens before they + # reach logger.info(...) lines or the HF cache scan. 
+ repo_id = _validate_logged_identifier(repo_id, "repo_id") + if variant is not None: + variant = _validate_logged_identifier(variant, "variant") if not _is_valid_repo_id(repo_id): raise HTTPException(status_code = 400, detail = "Invalid repo_id format") - # Check if model is currently loaded + # Round 25 P1 #2 / #3: round 15 added a path-ownership check to + # the diffusion guard below, but the llama.cpp and safetensors + # guards still only compared logical ``owner/repo`` strings to + # the loaded/loading identifier. If a chat or safetensors model + # was loaded via a LOCAL HF snapshot path (e.g. through the + # ``/load-local-path`` flow), the loaded identifier is the + # absolute snapshot path -- ``owner/repo`` never appears there, + # the guards passed, and ``DELETE /api/models/delete-cached`` + # could rmtree an actively mmap'd snapshot. + # + # Build the HF cache roots for ``repo_id`` ONCE up front and reuse + # them in all three guards (llama, safetensors, diffusion). Failure + # to scan the cache fails CLOSED on the assumption that we cannot + # verify ownership safely; mirrors the diffusion path-scan guard. + needle = repo_id.lower() + cache_repo_roots: list[Path] = [] + try: + for hf_cache in _all_hf_cache_scans(): + for repo_info in hf_cache.repos: + if ( + repo_info.repo_type == "model" + and repo_info.repo_id.lower() == needle + ): + try: + cache_repo_roots.append( + Path(repo_info.repo_path).expanduser().resolve() + ) + except Exception: + pass + except Exception as cache_scan_exc: + logger.warning( + "Could not scan HF cache during delete guard preflight: %s", + cache_scan_exc, + ) + raise HTTPException( + status_code = 503, + detail = ("Could not verify cache ownership before deleting. Try again."), + ) from cache_scan_exc + + def _owned_cache_path_matches(value: Optional[str], roots: list[Path]) -> bool: + """Return True if ``value`` resolves to (or contains, or is a + child of) any of the HF cache repo roots for the target repo. 
+ Used by the llama / safetensors guards to catch local snapshot + paths the same way the diffusion guard already does. + """ + if not value or not roots: + return False + try: + owned = Path(value).expanduser().resolve() + except Exception: + return False + for root in roots: + try: + if ( + owned == root + or _is_path_under(owned, root) + or _is_path_under(root, owned) + ): + return True + except Exception: + continue + return False + + # Round 26 P1 #13 / #14: helper/advisor GGUF loads run on a + # PRIVATE LlamaCppBackend, so the global backend below cannot see + # them. utils/datasets/llm_assist.py publishes the active repo + # via helper_advisor_owns_repo() for exactly this guard. Fail + # closed on the variant question (block any variant of the repo) + # because helper/advisor flows do not pass a variant through. + try: + from utils.datasets.llm_assist import helper_advisor_owns_repo + + if helper_advisor_owns_repo(repo_id): + raise HTTPException( + status_code = 409, + detail = "Cannot delete a model while AI Assist is using it", + ) + except HTTPException: + raise + except Exception as e: + logger.warning( + "Could not check helper/advisor backend status before cache delete: %s", e + ) + raise HTTPException( + status_code = 503, + detail = "Could not verify AI Assist load status before deleting cache", + ) from e + + # Check if model is currently loaded OR loading. is_active and + # not is_loaded means an llama-server download / startup is in + # flight; the cache delete would race the hf_hub_download / mmap. + # Fail CLOSED on exception (503) like the diffusion guard below: + # unverifiable load state means we cannot confirm the delete is + # safe. 
try: from routes.inference import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() - if llama_backend.is_loaded and llama_backend.model_identifier: - loaded_id = llama_backend.model_identifier.lower() - if loaded_id == repo_id.lower() or loaded_id.startswith(repo_id.lower()): + loaded_id_raw = llama_backend.model_identifier or "" + loaded_id = loaded_id_raw.lower() + loading_id_raw = getattr(llama_backend, "loading_model_identifier", None) or "" + loading_id = loading_id_raw.lower() + loading_variant = ( + getattr(llama_backend, "loading_hf_variant", None) or "" + ).lower() + # Also consult the pending-load identifier: a multi-GB HF + # download stays in ``loading_model_identifier`` until the + # download completes, before ``model_identifier`` is set + # (round 13 P1 #6). Without this check the cache directory + # the download was writing into could be rmtree'd mid-flight. + # Round 16 P1 #1: pair against ``loading_hf_variant`` so a + # delete of a DIFFERENT cached quant from the same repo + # (loading Q4_K_M, deleting cached Q8_0) is allowed; only + # block when the requested variant matches what is being + # downloaded. Mirrors the /delete-finetuned pairing. + requested_variant = (variant or "").lower() + # Round 25 P1 #2: also match by HF cache snapshot path so + # local-path GGUF chat loads block the cache delete that + # owns their snapshot. + loading_matches_repo = loading_id == needle or _owned_cache_path_matches( + loading_id_raw, cache_repo_roots + ) + if loading_matches_repo: + same_loading_variant = ( + not requested_variant + or not loading_variant + or requested_variant == loading_variant + ) + if same_loading_variant: + raise HTTPException( + status_code = 409, + detail = "Cannot delete a model while it is loading", + ) + # Exact match only (case-insensitive). Prefix match would + # block deleting unrelated ``org/model`` while + # ``org/model-v2`` is loaded -- same surface the diffusion + # guard fixed in round 5. 
Per-variant deletes that target a + # DIFFERENT quant than the loaded one are allowed so the + # llama and diffusion paths stay symmetric (round 14 P1 #7). + loaded_matches_repo = loaded_id == needle or _owned_cache_path_matches( + loaded_id_raw, cache_repo_roots + ) + if loaded_matches_repo and ( + llama_backend.is_loaded or getattr(llama_backend, "is_active", False) + ): + loaded_variant = (getattr(llama_backend, "hf_variant", None) or "").lower() + same_variant = ( + not requested_variant + or not loaded_variant + or requested_variant == loaded_variant + ) + if same_variant: raise HTTPException( status_code = 400, detail = "Unload the model before deleting", ) except HTTPException: raise - except Exception: - pass + except Exception as e: + logger.warning( + "Could not check llama.cpp backend status before cache delete: %s", e + ) + raise HTTPException( + status_code = 503, + detail = "Could not verify llama.cpp load status before deleting cache", + ) from e try: inference_backend = get_inference_backend() - if inference_backend.active_model_name: - active = inference_backend.active_model_name.lower() - if active == repo_id.lower() or active.startswith(repo_id.lower()): + loading_models = getattr(inference_backend, "loading_models", set()) or set() + # Loading set holds model identifiers currently being + # downloaded / instantiated; treat them like active loads + # so a delete cannot race a partial mmap. + # Exact match only on the logical ``owner/repo`` side, but + # also match local snapshot paths (round 25 P1 #3) so a + # safetensors model loaded from a local HF snapshot path + # cannot have its cache rmtree'd out from under it. 
+ for loading_model in loading_models: + ml_raw = loading_model or "" + ml = ml_raw.lower() + if ml == needle or _owned_cache_path_matches(ml_raw, cache_repo_roots): + raise HTTPException( + status_code = 409, + detail = "Cannot delete a model while it is loading", + ) + active_model_raw = inference_backend.active_model_name + if active_model_raw: + active = active_model_raw.lower() + if active == needle or _owned_cache_path_matches( + active_model_raw, cache_repo_roots + ): raise HTTPException( status_code = 400, detail = "Unload the model before deleting", ) except HTTPException: raise - except Exception: - pass + except Exception as e: + logger.warning( + "Could not check safetensors backend status before cache delete: %s", e + ) + raise HTTPException( + status_code = 503, + detail = "Could not verify safetensors load status before deleting cache", + ) from e + + # Also refuse to delete the cache underlying a loaded OR loading + # diffusion pipeline. The diffusion backend mmap's the GGUF + base + # repo weights and continues to read from the cache long after + # load; deleting them out from under it would corrupt generation. + # is_loading=True is also blocked because a mid-flight + # hf_hub_download / from_single_file would race the rmtree. + # Match exactly on repo_id (case-insensitive) instead of prefix to + # avoid blocking unrelated deletes like "org/model" while + # "org/model-v2" is loaded. + # During a swap (model A loaded, model B loading), status() + # exposes both via ``active_*`` and ``pending_*`` so we check + # every repo the backend currently owns. + # Fail-CLOSED on exception (return 503) like the neighboring + # llama.cpp / safetensors guards: we cannot verify whether the + # delete is safe, so refuse rather than risk corrupting the + # pipeline's mmap. 
+ try: + from core.inference.diffusion import get_diffusion_backend + + diff_backend = get_diffusion_backend() + # include_internal=True so we can pair owned raw paths against + # the HF cache snapshot root (round 16 P1 #5). + diff_status = diff_backend.status(include_internal = True) + if diff_status.get("is_loaded") or diff_status.get("is_loading"): + # ``needle`` and ``cache_repo_roots`` come from the + # preflight scan above; round 25 deduplicated the + # diffusion-specific rescan and now all three guards + # share the same fail-closed cache view. + # + # Pair each owned repo with the GGUF variant it actually + # owns (active or pending) so a swap in progress does not + # collapse both quants into the pending one (round 13 + # P1 #4). Per-variant delete is still allowed if the + # requested variant differs from the variant that owns + # the matched repo. + for owned_id, owned_gguf in _diffusion_owned_targets(diff_status): + if not owned_id: + continue + owned_matches_repo = owned_id.lower() == needle + if not owned_matches_repo and _owned_cache_path_matches( + owned_id, cache_repo_roots + ): + owned_matches_repo = True + if not owned_matches_repo: + continue + if _variant_delete_is_safe_for_owned_gguf(variant, owned_gguf): + continue + raise HTTPException( + status_code = 400, + detail = "Unload the diffusion image model before deleting", + ) + except HTTPException: + raise + except Exception as e: + logger.warning( + "Could not check diffusion backend status before cache delete: %s", + e, + ) + raise HTTPException( + status_code = 503, + detail = "Could not verify diffusion load status before deleting cache", + ) from e try: cache_scans = _all_hf_cache_scans() diff --git a/studio/backend/routes/training.py b/studio/backend/routes/training.py index 6e2413b3e9..5f0bdc38a4 100644 --- a/studio/backend/routes/training.py +++ b/studio/backend/routes/training.py @@ -127,6 +127,11 @@ async def start_training( This endpoint initiates training in the background and 
returns immediately. Use the /status endpoint to check training progress. """ + # Round 30 P1 #7: track whether we published a public-load pending + # entry so the outer finally clears it on either success or + # failure (including any early HTTPException raised by the helper + # check itself). + training_load_window_published = False try: logger.info(f"Starting training job with model: {request.model_name}") @@ -265,37 +270,48 @@ async def start_training( ) training_kwargs["trust_remote_code"] = True - # Free GPU memory: shut down any running inference/export subprocesses - # before training starts (they'd compete for VRAM otherwise) - try: - from core.inference import get_inference_backend - - inf_backend = get_inference_backend() - if inf_backend.active_model_name: - logger.info( - "Unloading inference model '%s' to free GPU memory for training", - inf_backend.active_model_name, - ) - inf_backend._shutdown_subprocess() - inf_backend.active_model_name = None - inf_backend.models.clear() - except Exception as e: - logger.warning("Could not unload inference model: %s", e) + # Symmetric lifecycle guard: refuse to start training while + # an export job is in flight. Round 10 review #1 -- the + # previous code went straight to ``_release_export_for``, + # which would terminate the in-flight export and corrupt + # the user's output artifact. Now we 409 first; the user + # stops the export and re-submits. 
+ from routes.inference import ( + _clear_public_load_window, + _raise_if_export_active, + _raise_if_helper_advisor_busy, + _release_chat_for, + _release_diffusion_for, + _release_export_for, + ) - try: - from core.export import get_export_backend - - exp_backend = get_export_backend() - if exp_backend.current_checkpoint: - logger.info( - "Shutting down export subprocess to free GPU memory for training" - ) - exp_backend._shutdown_subprocess() - exp_backend.current_checkpoint = None - exp_backend.is_vision = False - exp_backend.is_peft = False - except Exception as e: - logger.warning("Could not shut down export subprocess: %s", e) + _raise_if_export_active("training") + # Round 28 P1 #5: refuse before any release fires so AI Assist + # busy does not first tear down idle diffusion/export. + # Round 30 P1 #7: also publishes a public-load pending entry so + # a concurrent helper / advisor start cannot win the start + # lock between our snapshot and start_training flipping + # is_training_active. Paired clear lives in the outer + # ``finally`` below. + _raise_if_helper_advisor_busy("training") + training_load_window_published = True + # Round 18 P1 #8: release settled export FIRST so an export + # cleanup failure preserves the user's currently loaded chat + # model. The previous order (chat -> export) would drop chat + # and then refuse training when a wedged idle export raised, + # leaving the user with nothing loaded. + # Round 24 P1 #2: same reasoning extended to diffusion -> + # chat. A wedged diffusion unload used to fire AFTER the chat + # backend was already gone, so the user lost both chat and + # diffusion on a single failure mode. Order is now + # export -> diffusion -> chat, with chat as the last drop so + # earlier failures preserve it. 
+ await _release_export_for("training") + await _release_diffusion_for("training") + await _release_chat_for("training") + + # (Diffusion release moved above chat in round 24 P1 #2; + # the old trailing call was removed to avoid double-unload.) # start_training now spawns a subprocess (non-blocking) success = backend.start_training(job_id = job_id, **training_kwargs) @@ -319,12 +335,31 @@ async def start_training( except ValueError as e: logger.warning("Rejected training GPU selection: %s", e) raise HTTPException(status_code = 400, detail = str(e)) + except HTTPException: + # Preserve the intended status code from + # _raise_if_training_active / _raise_if_export_active + # (409) and the gpu-id 400 raises above. Without this + # explicit re-raise the broad ``except Exception`` below + # converts a deliberate 409 into a 500. + raise except Exception as e: logger.error(f"Error starting training: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to start training: {str(e)}", ) + finally: + # Round 30 P1 #7: clear the public-load pending entry once the + # start attempt has finished. Skipped when the helper-busy + # check itself raised (no publish to clear) so the counter + # stays in sync with publishes. + if training_load_window_published: + try: + from routes.inference import _clear_public_load_window + except Exception: + pass + else: + _clear_public_load_window("training") @router.post("/stop", response_model = TrainingStopResponse) diff --git a/studio/backend/tests/test_diffusion_backend.py b/studio/backend/tests/test_diffusion_backend.py new file mode 100644 index 0000000000..12ed6e3038 --- /dev/null +++ b/studio/backend/tests/test_diffusion_backend.py @@ -0,0 +1,1694 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. + +"""Unit tests for the diffusion image-generation backend. 
+ +These tests cover the surface area the routes layer relies on: + +* family detection from the public Unsloth GGUF naming conventions +* generation argument validation (empty prompt, bad steps, off-grid sizes) +* base64 PNG encoding round-trips +* status() shape stays compatible with the frontend status poller +* load/unload lifecycle with the heavy diffusers import monkey-patched + +Real GPU loads are exercised manually via the Studio probe (see +``studio/backend/tests/test_diffusion_smoke.py``); here we keep the +suite CPU- and import-free so the consolidated CI job and the +``unslothai/unsloth`` CI fork can both run it on Ubuntu, macOS, and +Windows runners with no diffusion dependencies installed. +""" + +from __future__ import annotations + +import base64 +import io +import sys +import types +from types import SimpleNamespace +from typing import Any + +import pytest + + +# ── module under test ──────────────────────────────────────────── + + +@pytest.fixture(autouse = True) +def _reset_singleton(monkeypatch): + """Reset the module-level singleton between tests so each test + starts from a known state without poking globals directly.""" + import core.inference.diffusion as d + + monkeypatch.setattr(d, "_singleton", None) + yield + + +# ── family detection ──────────────────────────────────────────── + + +def test_detect_family_flux2_klein(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-klein-4B-GGUF") + assert fam is not None + assert fam.name == "flux.2-klein" + assert fam.pipeline_class == "Flux2KleinPipeline" + assert fam.transformer_class == "Flux2Transformer2DModel" + # Family default base must point to a real Hub repo (not the bare + # "FLUX.2-klein" slug that does not exist). The frontend curated + # picker still passes base_repo explicitly per size so this default + # only fires for the "custom HF repo" mode. 
+ assert fam.base_repo == "black-forest-labs/FLUX.2-klein-base-4B" + + +def test_detect_family_flux2_dev_is_not_klein(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-dev-GGUF") + assert fam is not None + assert fam.name == "flux.2" + # Critical: FLUX.2 dev must NOT pick up the FLUX.2 klein pipeline + # because the transformer architectures and text encoder + # configurations are different. + assert fam.pipeline_class == "Flux2Pipeline" + + +def test_detect_family_flux1(): + from core.inference.diffusion import detect_family + + fam = detect_family("city96/FLUX.1-dev-gguf") + assert fam is not None + assert fam.name == "flux.1" + assert fam.pipeline_class == "FluxPipeline" + + +def test_detect_family_qwen_image(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/Qwen-Image-GGUF") + assert fam is not None + assert fam.name == "qwen-image" + + +def test_detect_family_override_wins_over_substring(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-dev-GGUF", override_family = "flux.1") + assert fam is not None + assert fam.name == "flux.1" + + +def test_detect_family_override_unknown_returns_none(): + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/FLUX.2-klein-4B-GGUF", override_family = "doesnotexist") + assert fam is None + + +def test_detect_family_unknown_returns_none(): + from core.inference.diffusion import detect_family + + assert detect_family("random/repo") is None + assert detect_family("") is None + + +def test_detect_family_sd35_is_not_sd3(): + """SD3.5 must NOT be matched as SD3 Medium. 
Pairing SD3.5 GGUFs + with the Medium base produces a misleading load.""" + from core.inference.diffusion import detect_family + + assert detect_family("unsloth/SD3.5-large-GGUF") is None + assert detect_family("unsloth/stable-diffusion-3.5-large-GGUF") is None + + +def test_detect_family_qwen_image_edit_is_not_qwen_image(): + """Qwen-Image-Edit must NOT be matched as Qwen-Image. The Edit + variant uses a different pipeline (image-to-image).""" + from core.inference.diffusion import detect_family + + assert detect_family("unsloth/Qwen-Image-Edit-GGUF") is None + assert detect_family("unsloth/Qwen-Image-Edit-2509-GGUF") is None + # Underscore spellings on the Hub must also be excluded; otherwise + # qwen_image_edit-GGUF silently matches the base Qwen-Image family. + assert detect_family("unsloth/qwen_image_edit-GGUF") is None + assert detect_family("unsloth/QwenImageEdit-GGUF") is None + + +def test_detect_family_finds_full_repo_sdxl(): + """SDXL lives in _FULL_REPO_FAMILIES, but the auto-detector must + still find it for ``stabilityai/stable-diffusion-xl-base-1.0`` so + the Custom HF repo entry point does not fail with 'Could not infer + a diffusion family' for the canonical SDXL repo.""" + from core.inference.diffusion import detect_family + + fam = detect_family("stabilityai/stable-diffusion-xl-base-1.0") + assert fam is not None + assert fam.name == "stable-diffusion-xl" + fam2 = detect_family("nerijs/sdxl-lora-test") + assert fam2 is not None + assert fam2.name == "stable-diffusion-xl" + + +def test_supported_families_payload_shape(): + from core.inference.diffusion import supported_families + + payload = supported_families() + assert isinstance(payload, list) + assert len(payload) >= 4 + for entry in payload: + assert set(entry.keys()) == {"name", "pipeline_class", "base_repo"} + + +# ── singleton ─────────────────────────────────────────────────── + + +def test_get_diffusion_backend_singleton(): + from core.inference.diffusion import get_diffusion_backend + + 
a = get_diffusion_backend() + b = get_diffusion_backend() + assert a is b + + +# ── status() shape ────────────────────────────────────────────── + + +def test_status_shape_unloaded(): + """Public status() (the browser-facing payload) must NOT contain + the guard-only ``active_*`` / ``pending_*`` fields (round 16 + P1 #5).""" + from core.inference.diffusion import get_diffusion_backend + + s = get_diffusion_backend().status() + expected_keys = { + "is_loaded", + "is_loading", + "repo_id", + "family", + "pipeline_class", + "base_repo", + "gguf_filename", + "device", + "dtype", + "loaded_at", + "last_error", + "supported_families", + } + assert expected_keys.issubset(s.keys()) + # Guard-facing fields are gated behind include_internal=True. + for guard_key in ( + "active_repo_id", + "active_base_repo", + "active_gguf_filename", + "pending_repo_id", + "pending_base_repo", + "pending_gguf_filename", + ): + assert guard_key not in s, f"public status() must not expose {guard_key}" + assert s["is_loaded"] is False + assert s["repo_id"] is None + + # Internal status() exposes the guard fields for delete/route use. 
+ s_internal = get_diffusion_backend().status(include_internal = True) + assert s_internal["active_gguf_filename"] is None + assert s_internal["pending_gguf_filename"] is None + + +# ── encode_png_base64 ─────────────────────────────────────────── + + +def test_encode_png_base64_round_trip(): + from PIL import Image + + from core.inference.diffusion import encode_png_base64 + + img = Image.new("RGB", (16, 16), color = (255, 0, 0)) + b64 = encode_png_base64(img) + raw = base64.b64decode(b64) + decoded = Image.open(io.BytesIO(raw)) + assert decoded.format == "PNG" + assert decoded.size == (16, 16) + + +# ── generation validation (no real pipeline) ──────────────────── + + +def _stub_pipeline(monkeypatch, *, returns = None, raises = None): + """Mount a fake torch pipeline on the singleton so generate_image's + argument validation runs without diffusers / torch being involved.""" + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + + class _StubPipe: + def __call__(self, **kwargs): + if raises is not None: + raise raises + + class _Out: + pass + + o = _Out() + o.images = [ + returns + or Image.new( + "RGB", (kwargs["width"], kwargs["height"]), color = (0, 255, 0) + ) + ] + return o + + backend._pipe = _StubPipe() + backend._device = "cpu" + backend._family = d._FAMILIES[0] + backend._repo_id = "stub/stub" + return backend + + +def test_generate_image_rejects_empty_prompt(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "prompt is empty"): + backend.generate_image(prompt = " ") + + +def test_generate_image_rejects_bad_steps(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "num_inference_steps"): + backend.generate_image(prompt = "cat", num_inference_steps = 0) + with pytest.raises(ValueError, match = "num_inference_steps"): + backend.generate_image(prompt = "cat", num_inference_steps = 999) + + +def 
test_generate_image_rejects_off_grid_size(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "multiples of 8"): + backend.generate_image(prompt = "cat", width = 513, height = 512) + + +def test_generate_image_rejects_oversized(monkeypatch): + backend = _stub_pipeline(monkeypatch) + with pytest.raises(ValueError, match = "width and height"): + backend.generate_image(prompt = "cat", width = 4096, height = 512) + + +def test_generate_image_calls_pipeline_with_kwargs(monkeypatch): + backend = _stub_pipeline(monkeypatch) + img = backend.generate_image( + prompt = "a red sphere", + negative_prompt = "blue", + num_inference_steps = 4, + guidance_scale = 1.0, + width = 256, + height = 256, + seed = 42, + ) + assert img.size == (256, 256) + + +def test_generate_image_unloaded_raises(monkeypatch): + import core.inference.diffusion as d + + backend = d.get_diffusion_backend() + backend._pipe = None + with pytest.raises(RuntimeError, match = "No diffusion model"): + backend.generate_image(prompt = "x") + + +def test_unload_clears_state(monkeypatch): + backend = _stub_pipeline(monkeypatch) + assert backend.is_loaded + backend.unload_model() + assert not backend.is_loaded + s = backend.status() + assert s["repo_id"] is None + assert s["family"] is None + + +# ── load_model (with monkey-patched diffusers) ────────────────── + + +def _install_fake_diffusers(monkeypatch, *, raise_on_pipeline = False): + """Build a tiny ``diffusers`` shim so we can exercise load_model + without dragging the real 1+ GB diffusers / torch import in.""" + from PIL import Image + + fake = types.ModuleType("diffusers") + fake.__version__ = "fake" + + class _FakeQuantConfig: + def __init__(self, compute_dtype = None): + self.compute_dtype = compute_dtype + + class _FakeTransformer: + @classmethod + def from_single_file(cls, path, **kw): + inst = cls() + inst.path = path + inst.qc = kw.get("quantization_config") + inst.dtype = kw.get("torch_dtype") + inst.config = 
kw.get("config") + inst.subfolder = kw.get("subfolder") + inst.token = kw.get("token") + return inst + + class _FakePipeline: + @classmethod + def from_pretrained(cls, base_repo, **kwargs): + if raise_on_pipeline: + raise RuntimeError("simulated load failure") + inst = cls() + inst.base_repo = base_repo + inst.kwargs = kwargs + return inst + + def __call__(self, **kwargs): + class _Out: + pass + + o = _Out() + o.images = [ + Image.new("RGB", (kwargs["width"], kwargs["height"]), color = (0, 0, 255)) + ] + return o + + def enable_model_cpu_offload(self): + self.cpu_offload = True + + def to(self, device): + self.device = device + return self + + fake.GGUFQuantizationConfig = _FakeQuantConfig + fake.Flux2KleinPipeline = _FakePipeline + fake.Flux2Transformer2DModel = _FakeTransformer + fake.Flux2Pipeline = _FakePipeline + fake.FluxPipeline = _FakePipeline + fake.FluxTransformer2DModel = _FakeTransformer + fake.QwenImagePipeline = _FakePipeline + fake.QwenImageTransformer2DModel = _FakeTransformer + fake.SD3Transformer2DModel = _FakeTransformer + fake.StableDiffusion3Pipeline = _FakePipeline + fake.StableDiffusionXLPipeline = _FakePipeline + + monkeypatch.setitem(sys.modules, "diffusers", fake) + + # Pretend HF Hub gave us a local file without actually fetching. + # Round 21: accept arbitrary kwargs (round 20 preflight adds + # ``filename="model_index.json"`` and round 21 preflight adds + # ``subfolder="transformer"``) so existing tests that exercise + # the GGUF path do not hit a TypeError from the fake signature. + fake_hub = types.ModuleType("huggingface_hub") + + def _fake_download(repo_id, filename, token = None, subfolder = None, **_kwargs): + sub = f"{subfolder}/" if subfolder else "" + return f"/fake/{repo_id}/{sub}{filename}" + + fake_hub.hf_hub_download = _fake_download + monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub) + + # Force CPU dtype so the test does not need CUDA. 
+ import core.inference.diffusion as d + + monkeypatch.setattr( + d.DiffusionBackend, + "_pick_device_and_dtype", + lambda self: ("cpu", "fake_dtype"), + ) + + # Round 16 reordered _release_other_gpu_owners_for_diffusion to + # run BEFORE the chat unload. That helper imports core.training / + # core.export and raises on active or unverifiable status. Stub + # both modules with idle backends so the load_model fast path + # works in CI environments where neither module is fully wired + # (Windows runners without the training/export deps). + fake_training_mod = types.ModuleType("core.training") + fake_training_mod.get_training_backend = lambda: SimpleNamespace( + is_training_active = lambda: False, + ) + monkeypatch.setitem(sys.modules, "core.training", fake_training_mod) + + fake_export_mod = types.ModuleType("core.export") + fake_export_mod.get_export_backend = lambda: SimpleNamespace( + is_export_active = lambda: False, + current_checkpoint = None, + ) + monkeypatch.setitem(sys.modules, "core.export", fake_export_mod) + + return fake + + +def test_load_model_unknown_family(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + with pytest.raises(RuntimeError, match = "Could not infer"): + backend.load_model("private/random-repo") + + +def test_load_model_gguf_path_happy(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + status = backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "flux-2-klein-4b-Q4_K_S.gguf", + ) + assert status["is_loaded"] is True + assert status["family"] == "flux.2-klein" + assert status["pipeline_class"] == "Flux2KleinPipeline" + # _smart_base_repo picks the distilled 4B (not the Base) for the + # "FLUX.2-klein-4B-GGUF" repo name. The Base variant kicks in only + # when "base" is part of the repo id. 
+ assert status["base_repo"] == "black-forest-labs/FLUX.2-klein-4B" + assert status["gguf_filename"] == "flux-2-klein-4b-Q4_K_S.gguf" + + +def test_load_model_recovers_after_failure(monkeypatch): + _install_fake_diffusers(monkeypatch, raise_on_pipeline = True) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + with pytest.raises(RuntimeError, match = "Failed to load diffusion model"): + backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "x.gguf", + ) + # Failed load must leave the singleton unloaded but with last_error set. + s = backend.status() + assert s["is_loaded"] is False + assert s["last_error"] and "simulated load failure" in s["last_error"] + + +def test_failed_swap_clears_previous_metadata(monkeypatch): + """After a successful load, a subsequent failing load must NOT + leave status() reporting the OLD repo/family/base_repo on top of + is_loaded=false. The clear must be atomic with the pipe drop.""" + import sys + + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + # First load succeeds. + backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "flux-2-klein-4b-Q4_K_S.gguf", + ) + s_before = backend.status() + assert s_before["is_loaded"] is True + assert s_before["repo_id"] == "unsloth/FLUX.2-klein-4B-GGUF" + + # Replace from_pretrained on the SAME fake module with a raising one + # without re-installing the rest of the fakes. 
+ fake = sys.modules["diffusers"] + + def _boom(cls, *a, **kw): + raise RuntimeError("simulated swap failure") + + fake.Flux2KleinPipeline.from_pretrained = classmethod(_boom) + + with pytest.raises(RuntimeError, match = "Failed to load diffusion model"): + backend.load_model( + "unsloth/FLUX.2-dev-GGUF", + gguf_filename = "flux2-dev-Q4_K_S.gguf", + ) + + s_after = backend.status() + assert s_after["is_loaded"] is False + # Critically: stale metadata from the previous successful load + # must be cleared, not just the pipe. + assert s_after["repo_id"] is None + assert s_after["family"] is None + assert s_after["base_repo"] is None + assert s_after["gguf_filename"] is None + assert s_after["last_error"] and "simulated swap failure" in s_after["last_error"] + + +def test_load_model_swap_drops_previous(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "flux-2-klein-4b-Q4_K_S.gguf", + ) + first_pipe = backend._pipe + backend.load_model( + "unsloth/FLUX.2-dev-GGUF", + gguf_filename = "flux2-dev-Q4_K_S.gguf", + ) + assert backend._pipe is not first_pipe + assert backend.status()["family"] == "flux.2" + + +def test_load_model_base_repo_override(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + status = backend.load_model( + "unsloth/FLUX.2-klein-9B-GGUF", + gguf_filename = "flux-2-klein-9b-Q4_K_S.gguf", + base_repo = "black-forest-labs/FLUX.2-klein-base-9B", + ) + assert status["base_repo"] == "black-forest-labs/FLUX.2-klein-base-9B" + + +def test_load_model_gguf_only_repo_without_filename_errors(monkeypatch): + """When the caller points at a -GGUF repo but forgets the filename, + surface a clear error instead of calling from_pretrained on the + GGUF-only repo (which 500s deep in diffusers).""" + 
_install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + with pytest.raises(RuntimeError, match = "looks like a GGUF-only repo"): + backend.load_model("unsloth/FLUX.2-klein-4B-GGUF") + + +def test_smart_base_repo_picks_9b(monkeypatch): + """For unsloth/FLUX.2-klein-9B-GGUF without an explicit base_repo, + the backend must fall through to FLUX.2-klein-9B, not the 4B base.""" + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + status = backend.load_model( + "unsloth/FLUX.2-klein-9B-GGUF", + gguf_filename = "flux-2-klein-9b-Q4_K_S.gguf", + ) + assert status["base_repo"] == "black-forest-labs/FLUX.2-klein-9B" + + +def test_smart_base_repo_picks_base_9b(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + status = backend.load_model( + "unsloth/FLUX.2-klein-base-9B-GGUF", + gguf_filename = "flux-2-klein-base-9b-Q4_K_S.gguf", + ) + assert status["base_repo"] == "black-forest-labs/FLUX.2-klein-base-9B" + + +def test_smart_base_repo_picks_base_4b(monkeypatch): + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + status = backend.load_model( + "unsloth/FLUX.2-klein-base-4B-GGUF", + gguf_filename = "flux-2-klein-base-4b-Q4_K_S.gguf", + ) + assert status["base_repo"] == "black-forest-labs/FLUX.2-klein-base-4B" + + +def test_gguf_transformer_load_passes_config_subfolder_token(monkeypatch): + """Diffusers-format GGUFs require config=+subfolder= + transformer at from_single_file time; gated GGUFs also need the + token. 
Verify all three kwargs are forwarded.""" + fake = _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + captured: dict = {} + original = fake.Flux2Transformer2DModel.from_single_file.__func__ + + def _capture(cls, path, **kw): + captured.update(kw) + return original(cls, path, **kw) + + fake.Flux2Transformer2DModel.from_single_file = classmethod(_capture) + + backend = get_diffusion_backend() + backend.load_model( + "unsloth/FLUX.2-klein-4B-GGUF", + gguf_filename = "flux-2-klein-4b-Q4_K_S.gguf", + hf_token = "hf_test_token", + ) + assert captured.get("config") == "black-forest-labs/FLUX.2-klein-4B" + assert captured.get("subfolder") == "transformer" + assert captured.get("token") == "hf_test_token" + + +def test_release_chat_backend_calls_unload_with_model_name(monkeypatch): + """The safetensors backend unload helper must call unload_model + with the active model name (the orchestrator's signature requires + it). The previous behaviour swallowed TypeError and left the chat + model resident, defeating the lifecycle handoff.""" + import sys + import types + + fake_pkg = types.ModuleType("core.inference") + calls: list = [] + + class _Stub: + active_model_name = "owner/some-model" + + def unload_model(self, name): + calls.append(name) + self.active_model_name = None + return True + + stub = _Stub() + fake_pkg.get_inference_backend = lambda: stub + monkeypatch.setitem(sys.modules, "core.inference", fake_pkg) + + # Skip the llama-server branch by also stubbing routes.inference. 
+ fake_routes = types.ModuleType("routes.inference") + fake_routes.get_llama_cpp_backend = lambda: types.SimpleNamespace(is_loaded = False) + monkeypatch.setitem(sys.modules, "routes.inference", fake_routes) + + from core.inference.diffusion import _release_chat_backend_for_diffusion + + _release_chat_backend_for_diffusion() + assert calls == ["owner/some-model"], calls + assert stub.active_model_name is None + + +def test_load_model_uses_safetensors_flag(monkeypatch): + """The pipeline.from_pretrained call must pass use_safetensors=True + so pickle-backed .bin weights are refused at load time.""" + fake = _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + captured: dict = {} + + original = fake.Flux2KleinPipeline.from_pretrained.__func__ + + def _capture(cls, base_repo, **kw): + captured.update(kw) + return original(cls, base_repo, **kw) + + fake.Flux2KleinPipeline.from_pretrained = classmethod(_capture) + + backend = get_diffusion_backend() + backend.load_model( + "unsloth/FLUX.2-klein-base-4B-GGUF", + gguf_filename = "flux-2-klein-base-4b-Q4_K_S.gguf", + ) + assert captured.get("use_safetensors") is True + + +def test_load_model_full_repo_does_not_substitute(monkeypatch): + """A full diffusers repo (no gguf_filename) must call from_pretrained + with the user-supplied repo, not the family default. This was the + silent-substitution bug surfaced by review.""" + fake = _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + + backend = get_diffusion_backend() + status = backend.load_model( + "owner/FLUX.1-finetune-diffusers", + family_override = "flux.1", + ) + # base_repo must echo the user repo, not the family default. + assert status["base_repo"] == "owner/FLUX.1-finetune-diffusers" + assert status["repo_id"] == "owner/FLUX.1-finetune-diffusers" + # And the fake pipeline records what we called from_pretrained with. 
+ assert backend._pipe.base_repo == "owner/FLUX.1-finetune-diffusers" + + +def test_load_model_concurrent_serialises(monkeypatch): + """Two concurrent load_model() calls must NOT both reach + pipeline_cls.from_pretrained at the same time (race fix).""" + _install_fake_diffusers(monkeypatch) + from core.inference.diffusion import get_diffusion_backend + import threading + import time as _t + + backend = get_diffusion_backend() + active = {"n": 0, "max": 0} + lock = threading.Lock() + + import sys as _sys + + fake_pipeline_cls = _sys.modules["diffusers"].Flux2KleinPipeline + original_from_pretrained = fake_pipeline_cls.from_pretrained.__func__ + + def _instrumented_from_pretrained(cls, base_repo, **kwargs): + with lock: + active["n"] += 1 + active["max"] = max(active["max"], active["n"]) + try: + _t.sleep(0.1) + return original_from_pretrained(cls, base_repo, **kwargs) + finally: + with lock: + active["n"] -= 1 + + fake_pipeline_cls.from_pretrained = classmethod(_instrumented_from_pretrained) + + errors: list = [] + + def _do_load(): + try: + backend.load_model( + "unsloth/FLUX.2-klein-base-4B-GGUF", + gguf_filename = "flux-2-klein-base-4b-Q4_K_S.gguf", + ) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target = _do_load) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, errors + assert ( + active["max"] == 1 + ), f"Expected concurrent loads to serialise; max_active={active['max']}" + + +def test_pipe_accepts_kwarg_filter(): + """The negative_prompt filter must drop the kwarg on classes that + do not accept it (FLUX.2 / FLUX.2 klein) and keep it on the rest.""" + from core.inference.diffusion import _pipe_accepts_kwarg + + class _NoNeg: + def __call__( + self, *, prompt, num_inference_steps, guidance_scale, width, height + ): + pass + + class _Neg: + def __call__( + self, + *, + prompt, + negative_prompt = None, + num_inference_steps, + guidance_scale, + width, + height, + ): + pass + 
+ class _VarKw: + def __call__(self, **kw): + pass + + assert _pipe_accepts_kwarg(_NoNeg(), "negative_prompt") is False + assert _pipe_accepts_kwarg(_Neg(), "negative_prompt") is True + # Anything with **kwargs is assumed to accept the kwarg (the + # alternative is to silently drop legitimate params). + assert _pipe_accepts_kwarg(_VarKw(), "negative_prompt") is True + + +def test_generate_image_strips_negative_prompt_on_flux2(monkeypatch): + """generate_image must drop negative_prompt when the loaded pipeline + does not accept it; otherwise FLUX.2 would 500 on a user-visible + field.""" + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + + received: dict = {} + + class _Flux2LikePipe: + # Signature mirrors Flux2Pipeline.__call__: NO negative_prompt. + # No **kw either, since the real FLUX.2 pipeline does not accept + # arbitrary kwargs (passing negative_prompt to it raises TypeError). + def __call__( + self, + *, + prompt, + num_inference_steps, + guidance_scale, + width, + height, + generator = None, + ): + received["prompt"] = prompt + + class _Out: + pass + + o = _Out() + o.images = [Image.new("RGB", (width, height), (1, 2, 3))] + return o + + backend._pipe = _Flux2LikePipe() + backend._device = "cpu" + backend._family = d._FAMILIES[0] + backend._repo_id = "stub/stub" + + # If generate_image forwarded negative_prompt, the pipeline call + # would raise TypeError. The PR's filter drops it, so the call + # succeeds and we observe the prompt was still delivered. 
+ backend.generate_image( + prompt = "a sloth", + negative_prompt = "blurry, low quality", + num_inference_steps = 4, + guidance_scale = 1.0, + width = 256, + height = 256, + ) + assert received["prompt"] == "a sloth" + + +def test_generate_image_keeps_negative_prompt_on_supporting_pipe(monkeypatch): + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + captured: dict = {} + + class _NegOK: + def __call__( + self, + *, + prompt, + negative_prompt = None, + num_inference_steps, + guidance_scale, + width, + height, + **kw, + ): + captured["negative_prompt"] = negative_prompt + + class _Out: + pass + + o = _Out() + o.images = [Image.new("RGB", (width, height), (4, 5, 6))] + return o + + backend._pipe = _NegOK() + backend._device = "cpu" + backend._family = d._FAMILIES[2] # flux.1 supports negative_prompt + backend._repo_id = "stub/stub" + + backend.generate_image( + prompt = "a sloth", + negative_prompt = "blurry", + num_inference_steps = 4, + guidance_scale = 1.0, + width = 256, + height = 256, + ) + assert captured["negative_prompt"] == "blurry" + + +def test_generate_image_forwards_true_cfg_scale_when_supported(monkeypatch): + """When a pipeline accepts both negative_prompt and true_cfg_scale + (QwenImagePipeline, FluxPipeline) the user's guidance_scale must be + forwarded as true_cfg_scale as well, otherwise the negative prompt + is silently ignored (Qwen leaves the default true_cfg_scale=4.0 + while the user value lands on guidance_scale).""" + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + captured: dict = {} + + class _QwenLikePipe: + def __call__( + self, + *, + prompt, + negative_prompt = None, + num_inference_steps, + guidance_scale, + true_cfg_scale = 4.0, + width, + height, + **kw, + ): + captured["guidance_scale"] = guidance_scale + captured["true_cfg_scale"] = true_cfg_scale + captured["negative_prompt"] = negative_prompt + + class _Out: + pass + + 
o = _Out() + o.images = [Image.new("RGB", (width, height), (7, 8, 9))] + return o + + backend._pipe = _QwenLikePipe() + backend._device = "cpu" + backend._family = d._FAMILIES[2] + backend._repo_id = "stub/stub" + + backend.generate_image( + prompt = "a sloth", + negative_prompt = "blurry", + num_inference_steps = 4, + guidance_scale = 7.5, + width = 256, + height = 256, + ) + assert captured["negative_prompt"] == "blurry" + assert captured["guidance_scale"] == 7.5 + assert captured["true_cfg_scale"] == 7.5 + + +def test_generate_image_skips_true_cfg_scale_without_negative_prompt(monkeypatch): + """Pipelines that accept true_cfg_scale must NOT have it forwarded + when no negative_prompt is given; otherwise distilled CFG models + would unintentionally switch into real-CFG mode and degrade + quality / double inference cost.""" + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + captured: dict = {} + + class _QwenLikePipe: + def __call__( + self, + *, + prompt, + negative_prompt = None, + num_inference_steps, + guidance_scale, + true_cfg_scale = 4.0, + width, + height, + **kw, + ): + captured["guidance_scale"] = guidance_scale + captured["true_cfg_scale"] = true_cfg_scale + + class _Out: + pass + + o = _Out() + o.images = [Image.new("RGB", (width, height), (1, 1, 1))] + return o + + backend._pipe = _QwenLikePipe() + backend._device = "cpu" + backend._family = d._FAMILIES[2] + backend._repo_id = "stub/stub" + + backend.generate_image( + prompt = "a sloth", + negative_prompt = None, + num_inference_steps = 4, + guidance_scale = 7.5, + width = 256, + height = 256, + ) + assert captured["guidance_scale"] == 7.5 + # Default left untouched: real CFG only activates with neg prompt. 
+ assert captured["true_cfg_scale"] == 4.0 + + +def test_generate_image_does_not_block_status(monkeypatch): + """status() must return promptly while a generation is in flight; + holding _lock for the whole forward froze the Images UI on the + polling endpoint for the entire (minutes long) generation.""" + import threading + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + pipe_started = threading.Event() + pipe_release = threading.Event() + + class _SlowPipe: + def __call__(self, **kw): + pipe_started.set() + # Wait until the test releases us; status() should return + # before this lock is released. + pipe_release.wait(timeout = 5) + + class _Out: + pass + + o = _Out() + o.images = [Image.new("RGB", (kw["width"], kw["height"]), (1, 2, 3))] + return o + + backend._pipe = _SlowPipe() + backend._device = "cpu" + backend._family = d._FAMILIES[0] + backend._repo_id = "stub/stub" + + t = threading.Thread( + target = backend.generate_image, + kwargs = dict( + prompt = "a sloth", + num_inference_steps = 1, + guidance_scale = 1.0, + width = 64, + height = 64, + ), + ) + t.start() + try: + assert pipe_started.wait(timeout = 5) + # Forward is in progress; status() must not block on _lock. + completed = [False] + + def call_status(): + backend.status() + completed[0] = True + + s = threading.Thread(target = call_status) + s.start() + s.join(timeout = 2) + assert completed[0], "status() blocked on generate_image" + finally: + pipe_release.set() + t.join(timeout = 5) + + +def test_load_publishes_pending_target_during_loading(): + """status() must expose the pending repo_id / base_repo / gguf + file while is_loading=True so cache- and finetuned-delete guards + can refuse to rmtree the repo being downloaded right now. + + The pending exposure is purely a state-shape contract: load_model + sets _loading + _pending_* under _lock at the start, and status() + snapshots them under _lock. 
Test the contract directly instead of + racing a fake pipeline through a background thread, which was + flaky on the Windows runner (the chat-release helpers' transitive + imports of core.training.resume failed there and the load thread + exited cleanly before the main thread observed the pending state). + """ + import core.inference.diffusion as d + + backend = d.DiffusionBackend() + # Simulate the state load_model publishes at the top of its + # critical section, before from_pretrained runs. + with backend._lock: + backend._loading = True + backend._pending_repo_id = "unsloth/FLUX.2-klein-4B-GGUF" + backend._pending_base_repo = "black-forest-labs/FLUX.2-klein-4B" + backend._pending_gguf_filename = "flux-2-klein-4b-Q4_K_S.gguf" + + public = backend.status() + assert public["is_loading"] is True + assert public["repo_id"] == "unsloth/FLUX.2-klein-4B-GGUF" + assert public["base_repo"] == "black-forest-labs/FLUX.2-klein-4B" + # Guard-facing internal payload also reports the pending fields + # under their dedicated keys. + internal = backend.status(include_internal = True) + assert internal["pending_repo_id"] == "unsloth/FLUX.2-klein-4B-GGUF" + assert internal["pending_base_repo"] == "black-forest-labs/FLUX.2-klein-4B" + assert internal["pending_gguf_filename"] == "flux-2-klein-4b-Q4_K_S.gguf" + + +def test_unload_waits_for_in_flight_generation(monkeypatch): + """unload_model() must not return is_loaded=False while a + generate_image forward is still iterating; otherwise routes/... 
+ callers see the pipe as freed while it still owns GPU memory and + can race a subsequent load.""" + import threading + import core.inference.diffusion as d + from PIL import Image + + backend = d.get_diffusion_backend() + started = threading.Event() + release = threading.Event() + generation_finished = threading.Event() + + class _SlowPipe: + def __call__(self, **kw): + started.set() + release.wait(timeout = 5) + + class _Out: + pass + + o = _Out() + o.images = [Image.new("RGB", (kw["width"], kw["height"]))] + return o + + backend._pipe = _SlowPipe() + backend._device = "cpu" + backend._family = d._FAMILIES[0] + backend._repo_id = "stub/stub" + + def do_generate(): + try: + backend.generate_image( + prompt = "x", + num_inference_steps = 1, + guidance_scale = 1.0, + width = 64, + height = 64, + ) + finally: + generation_finished.set() + + gen_thread = threading.Thread(target = do_generate) + gen_thread.start() + try: + assert started.wait(timeout = 5) + unload_returned = threading.Event() + + def do_unload(): + backend.unload_model() + unload_returned.set() + + unload_thread = threading.Thread(target = do_unload) + unload_thread.start() + # unload should block until release sets, NOT return early. + unload_thread.join(timeout = 0.5) + assert ( + not unload_returned.is_set() + ), "unload_model returned while generation was still running" + release.set() + unload_thread.join(timeout = 5) + assert unload_returned.is_set() + assert generation_finished.is_set() + finally: + release.set() + gen_thread.join(timeout = 5) + + +def test_bf16_falls_back_to_fp16_on_old_cuda(monkeypatch): + """CUDA availability does not imply BF16 support; old GPUs report + is_available()=True and is_bf16_supported()=False. 
The backend + must fall back to FP16 rather than picking BF16 and failing + deep inside from_pretrained.""" + import core.inference.diffusion as d + + class _FakeCuda: + @staticmethod + def is_available(): + return True + + @staticmethod + def is_bf16_supported(): + return False + + class _FakeBackends: + class mps: + @staticmethod + def is_available(): + return False + + class _FakeTorch: + cuda = _FakeCuda + backends = _FakeBackends + # Sentinel objects so the dtype identity comparison works. + bfloat16 = object() + float16 = object() + float32 = object() + + fake_torch = _FakeTorch() + monkeypatch.setitem(sys.modules, "torch", fake_torch) + + backend = d.DiffusionBackend() + device, dtype = backend._pick_device_and_dtype() + assert device == "cuda" + assert dtype is fake_torch.float16 + + +# ── round 13 regressions ────────────────────────────────────────── + + +def test_smart_base_repo_uses_windows_leaf_only(): + """Round 13 P2 #13: a Windows path whose PARENT directory contains + 'base' must not be misclassified as the Klein Base 4B variant.""" + from core.inference.diffusion import _smart_base_repo, detect_family + + repo = r"C:\Users\me\base\FLUX.2-klein-4B-GGUF" + fam = detect_family(repo) + assert fam is not None and fam.name == "flux.2-klein" + assert _smart_base_repo(fam, repo) == "black-forest-labs/FLUX.2-klein-4B" + + +def test_resolve_local_gguf_child_rejects_traversal(tmp_path): + """Round 13 P1 #2: gguf_filename must not escape the repo root.""" + from core.inference.diffusion import _resolve_local_gguf_child + + repo_root = tmp_path / "my-flux" + repo_root.mkdir() + (repo_root / "model.gguf").write_bytes(b"x") + sibling = tmp_path / "other.gguf" + sibling.write_bytes(b"y") + + assert _resolve_local_gguf_child(repo_root, "model.gguf").name == "model.gguf" + + # ``./model.gguf`` is normalised by PurePosixPath to ``model.gguf`` + # and stays inside the repo, so it is intentionally accepted. 
+ for bad in ("../other.gguf", "", "sub/../model.gguf"): + with pytest.raises(RuntimeError): + _resolve_local_gguf_child(repo_root, bad) + with pytest.raises(RuntimeError): + _resolve_local_gguf_child(repo_root, "/etc/passwd") + + +def test_resolve_local_gguf_child_rejects_backslash(tmp_path): + """Round 13 P1 #2: a Windows-style separator inside gguf_filename + must be rejected even on POSIX so it never becomes a literal name.""" + from core.inference.diffusion import _resolve_local_gguf_child + + repo_root = tmp_path / "my-flux" + repo_root.mkdir() + (repo_root / "model.gguf").write_bytes(b"x") + + with pytest.raises(RuntimeError): + _resolve_local_gguf_child(repo_root, r"..\\other.gguf") + + +def test_load_model_accepts_relative_local_dir(monkeypatch, tmp_path): + """Round 13 P1 #2: relative directory paths (Studio exports) must + NOT be routed through hf_hub_download.""" + import core.inference.diffusion as d + + repo_root = tmp_path / "exports" / "my-flux" + repo_root.mkdir(parents = True) + gguf_file = repo_root / "model.gguf" + gguf_file.write_bytes(b"x") + + # cwd so the relative path resolves to repo_root + monkeypatch.chdir(tmp_path) + + fake_transformer = object() + fake_pipe = SimpleNamespace( + to = lambda *a, **kw: None, + enable_model_cpu_offload = lambda: None, + ) + + class _FakeQuantConfig: + def __init__(self, **_): + pass + + class _FakeTransformerCls: + from_single_file_calls: list[tuple[str, dict]] = [] + + @classmethod + def from_single_file(cls, path, **kwargs): + cls.from_single_file_calls.append((path, kwargs)) + return fake_transformer + + class _FakePipeCls: + @classmethod + def from_pretrained(cls, base, **kwargs): + return fake_pipe + + fake_diffusers = SimpleNamespace( + __version__ = "0.99", + GGUFQuantizationConfig = _FakeQuantConfig, + Flux2Transformer2DModel = _FakeTransformerCls, + Flux2KleinPipeline = _FakePipeCls, + ) + + fake_torch = SimpleNamespace( + cuda = SimpleNamespace( + is_available = lambda: False, + is_bf16_supported 
= lambda: False, + empty_cache = lambda: None, + ), + bfloat16 = "bf16", + float16 = "fp16", + float32 = "fp32", + backends = SimpleNamespace( + mps = SimpleNamespace(is_available = lambda: False), + ), + ) + + def _boom(**kwargs): + # Round 20 P1 #1 added a base-repo preflight that downloads + # the diffusers ``model_index.json`` of the auto-picked + # companion repo BEFORE the chat unload. Round 21 P2 #6 + # added a second preflight for ``transformer/config.json`` + # on that same companion. Allow both preflight kinds through + # but still reject any attempt to download the GGUF itself, + # which is what this test guards. + if kwargs.get("filename") in ("model_index.json", "config.json"): + return "/tmp/preflight" + raise AssertionError("hf_hub_download must not run for a local dir") + + fake_hub = SimpleNamespace(hf_hub_download = _boom) + monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub) + monkeypatch.setitem(sys.modules, "diffusers", fake_diffusers) + monkeypatch.setitem(sys.modules, "torch", fake_torch) + + backend = d.DiffusionBackend() + backend.load_model( + repo_id = "exports/my-flux", + gguf_filename = "model.gguf", + family_override = "flux.2-klein", + enable_model_cpu_offload = False, + ) + + assert _FakeTransformerCls.from_single_file_calls + resolved_path = _FakeTransformerCls.from_single_file_calls[0][0] + assert str(gguf_file.resolve()) == resolved_path + + +def test_generate_image_with_metadata_returns_active_pipeline(monkeypatch): + """Round 13 P2 #9: meta returns the resident pipeline's identity.""" + import core.inference.diffusion as d + + backend = d.DiffusionBackend() + fake_fam = d.DiffusionFamily( + name = "flux.2-klein", + pipeline_class = "Flux2KleinPipeline", + transformer_class = "Flux2KleinTransformer3DModel", + base_repo = "black-forest-labs/FLUX.2-klein-4B", + aliases = (), + ) + + def _fake_unlocked(**kwargs): + from PIL import Image as _Image + + return _Image.new("RGB", (8, 8)) + + backend._pipe = object() + 
backend._repo_id = "unsloth/FLUX.2-klein-4B-GGUF" + backend._family = fake_fam + monkeypatch.setattr(backend, "_generate_image_unlocked", _fake_unlocked) + + _, meta = backend.generate_image_with_metadata(prompt = "x") + assert meta == { + "model": "unsloth/FLUX.2-klein-4B-GGUF", + "family": "flux.2-klein", + } + + +@pytest.mark.parametrize( + "repo_id", + [ + "unsloth/Qwen_Image-Edit-GGUF", + "unsloth/Qwen-Image_Edit-GGUF", + "unsloth/Qwen-ImageEdit-GGUF", + "unsloth/qwen-image_edit-2509-GGUF", + "unsloth/Qwen.Image.Edit-GGUF", + ], +) +def test_detect_family_qwen_image_edit_mixed_separators(repo_id): + """Round 14 P2 #8: every spelling of Qwen-Image-Edit must NOT + match the base Qwen-Image text-to-image family.""" + from core.inference.diffusion import detect_family + + assert detect_family(repo_id) is None + + +def test_redact_hf_tokens_removes_url_embedded_token(): + """Round 14 P2 #9: tokens embedded in user-supplied paths / + URLs must be scrubbed before logging.""" + from core.inference.diffusion import _redact_hf_tokens + + leaky = ( + "https://hf_abcdefghij0123456789@huggingface.co/unsloth/FLUX.2-klein-4B-GGUF" + ) + redacted = _redact_hf_tokens(leaky) + assert "hf_" not in redacted + assert "" in redacted + # Non-strings pass through unchanged so the helper is safe in + # logger argument lists where families / dtypes mix in. 
+ assert _redact_hf_tokens(None) is None + assert _redact_hf_tokens(42) == 42 + + +def test_status_preserves_active_gguf_subdir(monkeypatch): + """Round 14 P1 #4: status() must surface the original caller- + supplied gguf_filename (``BF16/model.gguf``) instead of the + collapsed basename.""" + import core.inference.diffusion as d + + backend = d.DiffusionBackend() + backend._pipe = object() + backend._repo_id = "unsloth/FLUX.2-klein-4B-GGUF" + backend._gguf_path = "/cache/models/unsloth/FLUX.2-klein-4B-GGUF/BF16/model.gguf" + backend._gguf_filename = "BF16/model.gguf" + backend._family = d.DiffusionFamily( + name = "flux.2-klein", + pipeline_class = "Flux2KleinPipeline", + transformer_class = "Flux2Transformer2DModel", + base_repo = "black-forest-labs/FLUX.2-klein-4B", + aliases = (), + ) + + s = backend.status(include_internal = True) + assert s["active_gguf_filename"] == "BF16/model.gguf" + # UI-facing field still collapses to the basename. + assert s["gguf_filename"] == "model.gguf" + + +def test_generator_uses_cpu_when_cpu_offload_enabled(monkeypatch): + """Round 14 P1 #6: seeded CUDA generation must NOT create a + CUDA torch.Generator when the pipeline was loaded with CPU + offload enabled, otherwise it crashes mid-forward.""" + import core.inference.diffusion as d + + backend = d.DiffusionBackend() + + class _FakePipe: + def __init__(self): + self.last_kwargs = None + + def __call__(self, **kwargs): + self.last_kwargs = kwargs + from PIL import Image + + return SimpleNamespace(images = [Image.new("RGB", (8, 8))]) + + fake_pipe = _FakePipe() + backend._pipe = fake_pipe + backend._device = "cuda" + backend._cpu_offload_enabled = True + + captured_devices: list[str] = [] + + class _FakeGenerator: + def __init__(self, device): + captured_devices.append(device) + + def manual_seed(self, seed): + return self + + class _FakeTorchCuda: + @staticmethod + def is_available(): + return True + + fake_torch = SimpleNamespace(Generator = _FakeGenerator, cuda = 
_FakeTorchCuda) + monkeypatch.setitem(sys.modules, "torch", fake_torch) + + backend._generate_image_unlocked(prompt = "x", seed = 7, width = 8, height = 8) + assert captured_devices == ["cpu"] + + +def test_smart_base_repo_uses_windows_leaf_only_already_set_separator_round14(): + """Sanity: relative paths still work after the Windows fix.""" + from core.inference.diffusion import _smart_base_repo, detect_family + + repo = "owner/FLUX.2-klein-9B-GGUF" + fam = detect_family(repo) + assert fam is not None + assert _smart_base_repo(fam, repo) == "black-forest-labs/FLUX.2-klein-9B" + + +def test_display_repo_id_collapses_absolute_path(tmp_path): + """Round 15 P2 #6: absolute local paths must NOT leak through + status(). Hub-style repo ids pass through unchanged. Uses + ``tmp_path`` so the absolute path is platform-correct (POSIX + ``/`` paths read as drive-relative on Windows).""" + from core.inference.diffusion import _display_repo_id + + # Hub id passes through. + assert ( + _display_repo_id("black-forest-labs/FLUX.2-klein-4B") + == "black-forest-labs/FLUX.2-klein-4B" + ) + # Absolute local path collapses to leaf. ``tmp_path`` is absolute + # on every OS pytest supports. + absolute_local = tmp_path / "private-flux" + absolute_local.mkdir() + assert _display_repo_id(str(absolute_local)) == "private-flux" + # HF tokens are scrubbed defensively. + leaky = "https://hf_abcdefghij0123456789@huggingface.co/owner/repo" + out = _display_repo_id(leaky) + assert "hf_" not in out + + +def test_detect_family_rejects_substring_collisions(): + """Round 15 P2 #8: ``flux.20-model`` must NOT match ``flux.2``.""" + from core.inference.diffusion import detect_family + + # ``flux.20`` is a different number and must not collide with ``flux.2``. + assert detect_family("owner/flux.20-model") is None + # ``stable-diffusion-30`` must not match ``stable-diffusion-3``. + assert detect_family("foo/stable-diffusion-30") is None + # Legitimate ``flux.2`` still matches. 
+ fam = detect_family("black-forest-labs/FLUX.2-dev") + assert fam is not None and fam.name == "flux.2" + + +def test_detect_family_compact_aliases_with_owner_prefix(): + """Round 16 P2 #9: compact aliases must match when the repo has + an owner prefix. ``unsloth/Flux2Klein-GGUF`` -> flux.2-klein + via the ``flux2-klein`` alias's compact form. Embedded compact + matches (e.g. ``flux2`` inside ``flux20``) must NOT match.""" + from core.inference.diffusion import detect_family + + fam = detect_family("unsloth/Flux2Klein-GGUF") + assert fam is not None and fam.name == "flux.2-klein" + # 20 is a different number; must not collide with flux.2. + assert detect_family("unsloth/Flux20-GGUF") is None + + +def test_public_status_does_not_leak_local_path_via_active_fields( + monkeypatch, tmp_path +): + """Round 16 P1 #5: even the guard-facing active_*/pending_* keys + must be absent from the public status payload. Uses ``tmp_path`` + so the absolute path is correct on every OS.""" + import core.inference.diffusion as d + + absolute_repo = tmp_path / "private-flux" + absolute_repo.mkdir() + absolute_base = tmp_path / "base-private" + absolute_base.mkdir() + + backend = d.DiffusionBackend() + backend._pipe = object() + backend._repo_id = str(absolute_repo) + backend._base_repo = str(absolute_base) + backend._family = d.DiffusionFamily( + name = "flux.2-klein", + pipeline_class = "Flux2KleinPipeline", + transformer_class = "Flux2Transformer2DModel", + base_repo = "black-forest-labs/FLUX.2-klein-4B", + aliases = (), + ) + + public = backend.status() + # UI-facing fields collapse to leaf and the guard-only fields are absent. 
+ assert public["repo_id"] == "private-flux" + assert public["base_repo"] == "base-private" + for key in ( + "active_repo_id", + "active_base_repo", + "active_gguf_filename", + "pending_repo_id", + "pending_base_repo", + "pending_gguf_filename", + ): + assert key not in public + + internal = backend.status(include_internal = True) + assert internal["active_repo_id"] == str(absolute_repo) + assert internal["active_base_repo"] == str(absolute_base) + + +def test_generate_image_with_metadata_redacts_local_path(monkeypatch, tmp_path): + """Round 16 P1 #6: the generation response must not echo a raw + absolute path back to the browser.""" + import core.inference.diffusion as d + + absolute_repo = tmp_path / "secret-flux" + absolute_repo.mkdir() + + backend = d.DiffusionBackend() + backend._pipe = object() + backend._repo_id = str(absolute_repo) + backend._family = d.DiffusionFamily( + name = "flux.2-klein", + pipeline_class = "Flux2KleinPipeline", + transformer_class = "Flux2Transformer2DModel", + base_repo = "black-forest-labs/FLUX.2-klein-4B", + aliases = (), + ) + + def _fake_unlocked(**kwargs): + from PIL import Image as _Image + + return _Image.new("RGB", (8, 8)) + + monkeypatch.setattr(backend, "_generate_image_unlocked", _fake_unlocked) + _, meta = backend.generate_image_with_metadata(prompt = "x") + assert meta["model"] == "secret-flux" + assert str(tmp_path) not in meta["model"] + + +def test_release_other_gpu_owners_raises_on_active_training(monkeypatch): + """Round 15 P1 #3: direct backend callers must not bypass the + route layer's training-active 409 guard.""" + import core.inference.diffusion as d + + fake_training_mod = types.ModuleType("core.training") + fake_training_mod.get_training_backend = lambda: SimpleNamespace( + is_training_active = lambda: True + ) + monkeypatch.setitem(sys.modules, "core.training", fake_training_mod) + + # Ensure export module import does not fail the test before the + # training raise lands. 
+ fake_export_mod = types.ModuleType("core.export") + fake_export_mod.get_export_backend = lambda: SimpleNamespace( + is_export_active = lambda: False, + current_checkpoint = None, + ) + monkeypatch.setitem(sys.modules, "core.export", fake_export_mod) + + with pytest.raises(RuntimeError) as exc_info: + d._release_other_gpu_owners_for_diffusion() + assert "Training is currently active" in str(exc_info.value) + + +def test_generate_image_with_metadata_blocks_concurrent_unload(monkeypatch): + """Round 13 P2 #9: _generate_lock serialises the forward AND the + meta snapshot, so a queued unload cannot wipe state in between.""" + import threading + import core.inference.diffusion as d + + backend = d.DiffusionBackend() + fake_fam = d.DiffusionFamily( + name = "flux.2-klein", + pipeline_class = "Flux2KleinPipeline", + transformer_class = "Flux2KleinTransformer3DModel", + base_repo = "black-forest-labs/FLUX.2-klein-4B", + aliases = (), + ) + + started = threading.Event() + finish = threading.Event() + + def _fake_unlocked(**kwargs): + from PIL import Image as _Image + + started.set() + # Hold long enough for the unload thread to race the metadata + # snapshot if the lock were released too early. + finish.wait(timeout = 2.0) + return _Image.new("RGB", (8, 8)) + + backend._pipe = object() + backend._repo_id = "unsloth/FLUX.2-klein-4B-GGUF" + backend._family = fake_fam + monkeypatch.setattr(backend, "_generate_image_unlocked", _fake_unlocked) + + result: list = [] + + def _gen(): + result.append(backend.generate_image_with_metadata(prompt = "x")) + + gen_thread = threading.Thread(target = _gen) + gen_thread.start() + assert started.wait(timeout = 2.0) + + def _unload(): + backend.unload_model() + + un_thread = threading.Thread(target = _unload) + un_thread.start() + # The unload must NOT have completed yet; it queues behind the + # generation's _generate_lock. 
+ un_thread.join(timeout = 0.2) + assert un_thread.is_alive() + finish.set() + gen_thread.join(timeout = 5.0) + un_thread.join(timeout = 5.0) + + assert result + _, meta = result[0] + assert meta["model"] == "unsloth/FLUX.2-klein-4B-GGUF" + assert meta["family"] == "flux.2-klein" diff --git a/studio/backend/tests/test_diffusion_routes.py b/studio/backend/tests/test_diffusion_routes.py new file mode 100644 index 0000000000..41869d4044 --- /dev/null +++ b/studio/backend/tests/test_diffusion_routes.py @@ -0,0 +1,336 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. + +"""Route-level tests for ``/api/inference/images/*``. + +Mounts the actual ``inference_router`` on a fresh FastAPI app with the +auth dependency replaced by a stub so we exercise the same FastAPI +handlers Studio ships in production. The diffusion backend is replaced +with an in-memory stub so we don't need diffusers / GPUs to run these. + +To stay runnable in a minimal CPU-only env, ``routes/inference.py`` +is loaded directly via ``importlib`` so we do NOT trigger +``routes/__init__.py`` -- that file eagerly imports training / +datasets / data_recipe / export and would drag in heavy deps +(matplotlib, etc.) that the diffusion tests do not need. +""" + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from PIL import Image + + +_BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(_BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(_BACKEND_ROOT)) + + +def _import_inference_module(): + """Load ``routes/inference.py`` without executing ``routes/__init__``. + + The package init imports training / datasets / data_recipe / export + routers, which pull in matplotlib / pandas / training stack. 
The + diffusion tests only need the inference module so we side-step the + package import via importlib.spec_from_file_location. + """ + # If a previous test already imported routes the normal way, reuse + # the cached module instead of re-loading. + cached = sys.modules.get("routes.inference") + if cached is not None: + return cached + target = _BACKEND_ROOT / "routes" / "inference.py" + spec = importlib.util.spec_from_file_location( + "routes.inference", + target, + # We do NOT set submodule_search_locations for routes itself + # because that would re-trigger routes/__init__.py. The module + # uses relative imports sparingly; absolute imports resolve via + # sys.path[0] = backend root. + ) + assert spec and spec.loader, "could not build spec for routes/inference.py" + module = importlib.util.module_from_spec(spec) + sys.modules["routes.inference"] = module + # Round 15 P3 #9: drop the half-initialised module from + # sys.modules if exec_module() raises, otherwise later tests pick + # up the poisoned entry and report a misleading AttributeError + # instead of the original ImportError. 
+ try: + spec.loader.exec_module(module) + except Exception: + sys.modules.pop("routes.inference", None) + raise + return module + + +class _FakeBackend: + def __init__(self) -> None: + self._loaded = False + self._repo: str | None = None + self.calls: list[dict] = [] + + @property + def is_loaded(self) -> bool: + return self._loaded + + def status(self) -> dict: + return { + "is_loaded": self._loaded, + "is_loading": False, + "repo_id": self._repo, + "family": "flux.2-klein" if self._loaded else None, + "pipeline_class": "Flux2KleinPipeline" if self._loaded else None, + "base_repo": "black-forest-labs/FLUX.2-klein" if self._loaded else None, + "gguf_filename": None, + "active_repo_id": self._repo, + "active_base_repo": ( + "black-forest-labs/FLUX.2-klein" if self._loaded else None + ), + # Round 14: guard-facing GGUF filename is now the full + # caller-supplied value, but this fake never sets one so + # both active and pending stay None. + "active_gguf_filename": None, + "pending_repo_id": None, + "pending_base_repo": None, + "pending_gguf_filename": None, + "device": "cpu", + "dtype": "torch.bfloat16", + "loaded_at": 0, + "last_error": None, + "supported_families": [], + } + + def load_model(self, repo_id, **kw): + self.calls.append({"op": "load", "repo_id": repo_id, **kw}) + self._loaded = True + self._repo = repo_id + return self.status() + + def unload_model(self) -> dict: + self._loaded = False + self._repo = None + return {"is_loaded": False} + + def generate_image(self, **kw): + self.calls.append({"op": "generate", **kw}) + return Image.new("RGB", (kw["width"], kw["height"]), color = (123, 45, 67)) + + def generate_image_with_metadata(self, **kw): + image = self.generate_image(**kw) + meta = { + "model": self._repo, + "family": "flux.2-klein" if self._loaded else None, + } + return image, meta + + +@pytest.fixture +def app_with_stub(monkeypatch): + """Build a FastAPI app that mounts the real inference router with + auth disabled and the diffusion backend 
swapped for a stub.""" + inf = _import_inference_module() + import core.inference.diffusion as d + + stub = _FakeBackend() + # Override the singleton accessor the route uses. + monkeypatch.setattr(d, "get_diffusion_backend", lambda: stub) + monkeypatch.setattr(inf, "_get_diffusion_backend", lambda: stub) + + app = FastAPI() + # Diffusion image routes live on studio_router so they are NOT + # exposed under /v1 (which would let OpenAI-compat clients + # trigger Studio-only side effects). + app.include_router(inf.router, prefix = "/api/inference") + app.include_router(inf.studio_router, prefix = "/api/inference") + # Bypass auth by overriding the dependency. + from auth.authentication import get_current_subject + + app.dependency_overrides[get_current_subject] = lambda: "test-user" + + return app, stub + + +def test_status_when_unloaded(app_with_stub): + app, _ = app_with_stub + c = TestClient(app) + r = c.get("/api/inference/images/status") + assert r.status_code == 200 + body = r.json() + assert body["is_loaded"] is False + assert body["repo_id"] is None + + +def test_generate_without_load_returns_400(app_with_stub): + app, _ = app_with_stub + c = TestClient(app) + r = c.post( + "/api/inference/images/generate", + json = {"prompt": "a red sphere"}, + ) + assert r.status_code == 400 + assert "No diffusion model" in r.json()["detail"] + + +def test_load_then_generate_round_trip(app_with_stub): + app, stub = app_with_stub + c = TestClient(app) + + r = c.post( + "/api/inference/images/load", + json = { + "repo_id": "unsloth/FLUX.2-klein-4B-GGUF", + "gguf_filename": "flux-2-klein-4b-Q4_K_S.gguf", + }, + ) + assert r.status_code == 200, r.text + assert r.json()["is_loaded"] is True + + r = c.post( + "/api/inference/images/generate", + json = { + "prompt": "a tiny synth-pop album cover", + "width": 256, + "height": 256, + "num_inference_steps": 4, + "seed": 7, + }, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["image_b64"] + assert 
body["image_mime"] == "image/png" + assert body["width"] == 256 + assert body["height"] == 256 + assert body["seed"] == 7 + assert body["duration_ms"] >= 0 + + # Round-trip the base64 -> PIL to confirm it is a real PNG of the + # right size and not, say, an empty string. + import base64 + import io + + raw = base64.b64decode(body["image_b64"]) + decoded = Image.open(io.BytesIO(raw)) + assert decoded.format == "PNG" + assert decoded.size == (256, 256) + + # Backend stub should have recorded both calls. + ops = [c["op"] for c in stub.calls] + assert ops == ["load", "generate"] + + +def test_generate_rejects_off_grid_size(app_with_stub): + app, stub = app_with_stub + c = TestClient(app) + c.post( + "/api/inference/images/load", + json = { + "repo_id": "unsloth/FLUX.2-klein-4B-GGUF", + "gguf_filename": "x.gguf", + }, + ) + r = c.post( + "/api/inference/images/generate", + json = {"prompt": "x", "width": 513, "height": 512}, + ) + # Pydantic v2 wraps validator errors in 422 by default. + assert r.status_code in (400, 422), r.text + + +def test_unload_clears_state(app_with_stub): + app, _ = app_with_stub + c = TestClient(app) + c.post( + "/api/inference/images/load", + json = {"repo_id": "unsloth/FLUX.2-klein-4B-GGUF", "gguf_filename": "x.gguf"}, + ) + r = c.post("/api/inference/images/unload") + assert r.status_code == 200 + assert r.json()["is_loaded"] is False + r = c.get("/api/inference/images/status") + assert r.json()["is_loaded"] is False + + +def test_load_rejects_embedded_hf_token(app_with_stub): + """Round 15 P1 #5: URL-embedded ``hf_xxxxx`` tokens in repo_id / + base_repo must be rejected with 422 so they never reach + ``self._repo_id`` and get echoed back by ``status()``.""" + app, _ = app_with_stub + c = TestClient(app) + r = c.post( + "/api/inference/images/load", + json = { + "repo_id": "https://hf_abcdefghij0123456789@huggingface.co/owner/repo", + }, + ) + assert r.status_code == 422, r.text + body = r.json() + text = repr(body).lower() + assert 
"hf_token" in text or "embed" in text + # base_repo is also rejected. + r = c.post( + "/api/inference/images/load", + json = { + "repo_id": "owner/repo", + "gguf_filename": "x.gguf", + "base_repo": "https://hf_abcdefghij0123456789@huggingface.co/base/repo", + }, + ) + assert r.status_code == 422, r.text + + +def test_load_rejects_control_chars_in_repo_id(app_with_stub): + """Newline-laden repo ids must be rejected by Pydantic BEFORE the + log line that echoes them. Catches log-injection from authenticated + callers (issues a 422 instead of forging a fake log line).""" + app, _ = app_with_stub + c = TestClient(app) + r = c.post( + "/api/inference/images/load", + json = {"repo_id": "owner/model\nFAKE_LOG_LINE"}, + ) + assert r.status_code == 422, r.text + body = r.json() + text = repr(body).lower() + assert "control" in text or "repo_id" in text + + +def test_generate_rejects_oversize_seed(app_with_stub): + """Huge seeds raise inside torch.Generator.manual_seed; Pydantic + must clamp first with a 422 instead of a 500 traceback.""" + app, _ = app_with_stub + c = TestClient(app) + c.post( + "/api/inference/images/load", + json = {"repo_id": "unsloth/FLUX.2-klein-4B-GGUF", "gguf_filename": "x.gguf"}, + ) + r = c.post( + "/api/inference/images/generate", + json = {"prompt": "x", "seed": 2**100}, + ) + assert r.status_code == 422, r.text + + +def test_generate_accepts_uint64_max_seed(app_with_stub): + """Boundary value: 2**64 - 1 (uint64 max) is the largest seed + torch.Generator on CPU accepts; reject would frustrate users + who paste large seeds from other tooling.""" + app, _ = app_with_stub + c = TestClient(app) + c.post( + "/api/inference/images/load", + json = {"repo_id": "unsloth/FLUX.2-klein-4B-GGUF", "gguf_filename": "x.gguf"}, + ) + r = c.post( + "/api/inference/images/generate", + json = {"prompt": "x", "seed": (2**64) - 1}, + ) + # The fake backend returns 200 on success; we only care that the + # request did NOT 422 on seed bounds. 
+ assert r.status_code != 422, r.text diff --git a/studio/backend/tests/test_llama_cpp_no_context_shift.py b/studio/backend/tests/test_llama_cpp_no_context_shift.py index b9f25faf88..c49900812e 100644 --- a/studio/backend/tests/test_llama_cpp_no_context_shift.py +++ b/studio/backend/tests/test_llama_cpp_no_context_shift.py @@ -66,15 +66,24 @@ def _load_model_source() -> str: - """Return the source of ``LlamaCppBackend.load_model``. - - Using ``inspect.getsource`` instead of reading the file directly - scopes the assertions to the function that actually launches - llama-server, so neither the presence check nor the location check - can be fooled by a stray occurrence of ``"--no-context-shift"`` - elsewhere in the module. + """Return the source of ``LlamaCppBackend.load_model`` PLUS the + internal ``_load_model_impl_locked`` body it delegates to. + + Studio's diffusion PR split ``load_model`` into a thin wrapper + that publishes ``_loading_model_identifier`` under + ``_serial_load_lock`` and an inner ``_load_model_impl_locked`` + body that actually spawns llama-server. The launch flags and the + ``_wait_for_vram_settle`` call now live in the inner method, so + inspecting only ``load_model`` would miss them. Concatenating the + two sources keeps these source-inspection regression tests + working without weakening the scope (we still only look at the + two load entry points, not the entire module). 
""" - return inspect.getsource(llama_cpp_module.LlamaCppBackend.load_model) + parts = [inspect.getsource(llama_cpp_module.LlamaCppBackend.load_model)] + impl = getattr(llama_cpp_module.LlamaCppBackend, "_load_model_impl_locked", None) + if impl is not None: + parts.append(inspect.getsource(impl)) + return "\n".join(parts) def test_no_context_shift_is_in_load_model(): diff --git a/studio/backend/tests/test_llama_cpp_wait_for_vram_settle.py b/studio/backend/tests/test_llama_cpp_wait_for_vram_settle.py index 00295d6283..5735ed67bc 100644 --- a/studio/backend/tests/test_llama_cpp_wait_for_vram_settle.py +++ b/studio/backend/tests/test_llama_cpp_wait_for_vram_settle.py @@ -271,10 +271,19 @@ def test_load_model_calls_helper_outside_lock_and_uses_last_kill_timestamp(): """Pin the call site: outside Phase 3 lock, gated on the timestamp, no ``had_live_process`` in-band flag regression. Mirrors the ``inspect.getsource`` pattern from ``test_llama_cpp_no_context_shift``. + + Studio's diffusion PR split ``load_model`` into a thin wrapper + + ``_load_model_impl_locked`` that actually launches llama-server, so + look at both sources to keep the assertions scoped to the load entry + points and not the entire module. 
""" import inspect - src = inspect.getsource(LlamaCppBackend.load_model) + parts = [inspect.getsource(LlamaCppBackend.load_model)] + impl = getattr(LlamaCppBackend, "_load_model_impl_locked", None) + if impl is not None: + parts.append(inspect.getsource(impl)) + src = "\n".join(parts) assert "_wait_for_vram_settle" in src assert "since_kill" in src assert "self._last_kill_monotonic" in src diff --git a/studio/backend/utils/datasets/llm_assist.py b/studio/backend/utils/datasets/llm_assist.py index 4c66d2ebf6..91a7bb8ea9 100644 --- a/studio/backend/utils/datasets/llm_assist.py +++ b/studio/backend/utils/datasets/llm_assist.py @@ -18,7 +18,9 @@ import os import re import textwrap +import threading import time +from collections import Counter from itertools import islice from typing import Any, Optional @@ -31,6 +33,137 @@ README_MAX_CHARS = 1500 +# Round 26 P1 #13 / #14: helper/advisor run on PRIVATE LlamaCppBackend +# instances. Expose loading repo ids through thread-safe Counters so +# DELETE /api/models/delete-cached can block while a helper or +# advisor still owns the cache. +# +# Round 28 P1 #2: split into CACHE vs GPU refcounts. precache_helper_gguf +# downloads files (cache ownership) without occupying VRAM (GPU +# ownership), so collapsing them caused the public GPU handoffs to +# 503 during a background precache that did not need the GPU. +# * CACHE: blocks delete-cache for any active downloader / loader +# * GPU : blocks public chat / training / export / diffusion loads +_HELPER_ADVISOR_CACHE_REFCOUNT: Counter[str] = Counter() +_HELPER_ADVISOR_GPU_REFCOUNT: Counter[str] = Counter() +# Round 30 P1 #7-#10: counter of public GPU workloads (chat / +# diffusion / training / export) that have passed the helper-busy +# snapshot but have not yet flipped their public ownership flags +# (``llama.is_loaded`` / ``loading_model_identifier`` / +# ``current_checkpoint`` / ``is_training_active``). 
Helper / advisor +# starts consult this so they cannot win the start lock and race a +# public load that already destroyed the previous owner. +_PUBLIC_LOAD_PENDING_COUNT: Counter[str] = Counter() +_HELPER_ADVISOR_LOCK = threading.Lock() +# Round 28 P1 #7 / #8 / #10: serialize helper / advisor STARTS so two +# concurrent invocations cannot both pass the busy precheck before +# either registers. Held only across the precheck + register window, +# not across the full helper run. +# Round 30 P1 #7-#10: public GPU loads also enter under this lock to +# publish their pending counter so a concurrent helper / advisor +# start sees the pending public owner and refuses VRAM. +_HELPER_ADVISOR_START_LOCK = threading.Lock() + + +def helper_advisor_owns_repo(repo_id: str) -> bool: + """Return True if any helper/advisor activity (precache OR live + helper / advisor load) currently owns this HF repo id.""" + if not repo_id: + return False + needle = repo_id.lower() + with _HELPER_ADVISOR_LOCK: + return _HELPER_ADVISOR_CACHE_REFCOUNT.get(needle, 0) > 0 + + +def helper_advisor_busy() -> bool: + """True if any helper/advisor load is currently OCCUPYING THE GPU. + Round 28 P1 #2: must not return True for a precache-only download + (it owns disk cache, not VRAM).""" + with _HELPER_ADVISOR_LOCK: + return sum(_HELPER_ADVISOR_GPU_REFCOUNT.values()) > 0 + + +def _register_helper_advisor_repo(repo_id: str, *, gpu_owner: bool = True) -> None: + """Register a helper/advisor activity. 
Set ``gpu_owner=False`` for + precache-only downloads that need cache-delete protection but do + not load weights into VRAM.""" + if not repo_id: + return + needle = repo_id.lower() + with _HELPER_ADVISOR_LOCK: + _HELPER_ADVISOR_CACHE_REFCOUNT[needle] += 1 + if gpu_owner: + _HELPER_ADVISOR_GPU_REFCOUNT[needle] += 1 + + +def _unregister_helper_advisor_repo(repo_id: str, *, gpu_owner: bool = True) -> None: + if not repo_id: + return + needle = repo_id.lower() + with _HELPER_ADVISOR_LOCK: + _HELPER_ADVISOR_CACHE_REFCOUNT[needle] -= 1 + if _HELPER_ADVISOR_CACHE_REFCOUNT[needle] <= 0: + _HELPER_ADVISOR_CACHE_REFCOUNT.pop(needle, None) + if gpu_owner: + _HELPER_ADVISOR_GPU_REFCOUNT[needle] -= 1 + if _HELPER_ADVISOR_GPU_REFCOUNT[needle] <= 0: + _HELPER_ADVISOR_GPU_REFCOUNT.pop(needle, None) + + +def _publish_public_load_pending(workload: str) -> None: + """Mark a public GPU workload as mid-handoff. Must be called under + ``_HELPER_ADVISOR_START_LOCK`` immediately after the helper-busy + snapshot succeeded (round 30 P1 #7-#10).""" + if not workload: + return + needle = workload.lower() + with _HELPER_ADVISOR_LOCK: + _PUBLIC_LOAD_PENDING_COUNT[needle] += 1 + + +def _release_public_load_pending(workload: str) -> None: + """Decrement the pending public-load counter once per matched + publish. Safe to call in finally even if the load failed.""" + if not workload: + return + needle = workload.lower() + with _HELPER_ADVISOR_LOCK: + _PUBLIC_LOAD_PENDING_COUNT[needle] -= 1 + if _PUBLIC_LOAD_PENDING_COUNT[needle] <= 0: + _PUBLIC_LOAD_PENDING_COUNT.pop(needle, None) + + +def public_load_pending(*, excluding: str | None = None) -> bool: + """True if any public GPU workload has passed its helper-busy + snapshot but not yet flipped its public ownership flags. Helper / + advisor starts treat this as busy so they cannot race a public + load mid-handoff. + + Round 38 P1: ``excluding`` lets a route-wrapped backend call + skip the marker its own route layer already published (e.g. 
the + diffusion route publishes ``diffusion`` before calling into + ``backend.load_model``, which publishes ``diffusion-backend`` -- + the backend should ignore its own ``diffusion`` marker so the + parity check does not self-block) while still seeing every + OTHER in-flight public workload.""" + ignored = excluding.lower() if excluding else None + with _HELPER_ADVISOR_LOCK: + return any( + count > 0 and workload != ignored + for workload, count in _PUBLIC_LOAD_PENDING_COUNT.items() + ) + + +def public_load_pending_for(workload: str) -> bool: + """True if a specific public GPU workload is mid-handoff. Used by + release helpers to refuse a destructive teardown while the matching + /export/* or /chat /load_* route is still in its publish window.""" + if not workload: + return False + needle = workload.lower() + with _HELPER_ADVISOR_LOCK: + return _PUBLIC_LOAD_PENDING_COUNT.get(needle, 0) > 0 + def _strip_think_tags(text: str) -> str: """Strip ... reasoning blocks emitted by some models. @@ -72,6 +205,12 @@ def precache_helper_gguf(): "UNSLOTH_HELPER_MODEL_VARIANT", DEFAULT_HELPER_MODEL_VARIANT ) + # Round 27 P1 #4: register the repo so DELETE /api/models/delete-cached + # cannot rmtree the cache directory while we are mid-download. + # Round 28 P1 #2: precache only downloads files; it does NOT occupy + # VRAM. Use gpu_owner=False so helper_advisor_busy() does not block + # public GPU workloads during a background pre-cache. 
+ _register_helper_advisor_repo(repo, gpu_owner = False) try: from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import disable_progress_bars, enable_progress_bars @@ -103,12 +242,143 @@ def precache_helper_gguf(): except Exception as e: logger.warning(f"Failed to pre-cache helper GGUF: {e}") finally: + _unregister_helper_advisor_repo(repo, gpu_owner = False) try: enable_progress_bars() except Exception as e: pass +def _diffusion_image_model_busy() -> bool: + """Round 22 P1 #2 / #3: helper / advisor GGUFs share VRAM with + the Images page diffusion pipeline. Public chat / training / + export routes call the strict ``_release_diffusion_for`` helper + before allocating, but these dataset-side helpers used to load + llama-server directly with no diffusion guard at all. Skip the + helper GGUF when ``DiffusionBackend.status()`` reports loaded / + loading so we do not double-own VRAM. Fail closed (treat as + busy) on any status() error to preserve the resident image + model rather than racing it for memory. + """ + try: + from core.inference.diffusion import get_diffusion_backend + except Exception: + return False + try: + status = get_diffusion_backend().status() + except Exception: + return True + return bool(status.get("is_loaded") or status.get("is_loading")) + + +def _gpu_workload_busy_for_helper() -> bool: + """Round 23 P1 #3 / #4: the diffusion-only guard from round 22 + let the helper / advisor GGUF run on top of a live training run + or a resident export checkpoint. Extend the busy check to those + workloads too so any GPU owner (Images, Training, Export) + blocks the helper instead of double-owning VRAM. Each step + fails closed: an unverifiable status counts as busy so the + user's primary workload is preserved over the optional helper. + + Round 24 P1 #1: extended to also catch a Chat-backend GPU owner. 
+ The helper GGUF used to run on top of a loaded GGUF chat model + (llama-server) or safetensors chat model and OOM their shared + GPU; mirror the diffusion check by inspecting llama + ``is_loaded`` / ``is_active`` / ``loading_model_identifier`` and + safetensors ``active_model_name`` / ``loading_models``. + + Round 28 P1 #9: also catch another helper / advisor that already + owns a private LlamaCppBackend. Without this two concurrent + helpers could both pass the precheck and OOM each other. + """ + if helper_advisor_busy(): + logger.info( + "Skipping helper GGUF while another helper/advisor is using the GPU" + ) + return True + # Round 30 P1 #7-#10: a public GPU load (chat / diffusion / training / + # export) that has passed its busy snapshot but not yet flipped its + # public ownership flags is still mid-handoff. Refuse so the helper + # does not race it for VRAM after the previous owner was torn down. + if public_load_pending(): + logger.info("Skipping helper GGUF while a public GPU load is mid-handoff") + return True + if _diffusion_image_model_busy(): + return True + + try: + from routes.inference import get_llama_cpp_backend + except Exception: + pass + else: + try: + llama = get_llama_cpp_backend() + if ( + getattr(llama, "is_loaded", False) + or getattr(llama, "is_active", False) + or getattr(llama, "loading_model_identifier", None) + ): + logger.info( + "Skipping helper GGUF while a GGUF chat model is loaded/loading" + ) + return True + except Exception: + logger.info( + "Skipping helper GGUF because llama-server status is unavailable" + ) + return True + + try: + from core.inference import get_inference_backend + except Exception: + pass + else: + try: + inf = get_inference_backend() + active = getattr(inf, "active_model_name", None) + loading = set(getattr(inf, "loading_models", set()) or set()) + if active or loading: + logger.info( + "Skipping helper GGUF while a safetensors chat model is loaded/loading" + ) + return True + except Exception: + 
logger.info( + "Skipping helper GGUF because safetensors chat status is unavailable" + ) + return True + + try: + from core.training import get_training_backend + except Exception: + pass + else: + try: + if get_training_backend().is_training_active(): + logger.info("Skipping helper GGUF while training is active") + return True + except Exception: + logger.info("Skipping helper GGUF because training status is unavailable") + return True + + try: + from core.export import get_export_backend + except Exception: + return False + + try: + exp = get_export_backend() + is_active = getattr(exp, "is_export_active", None) + if (is_active and is_active()) or getattr(exp, "current_checkpoint", None): + logger.info("Skipping helper GGUF while export owns the GPU") + return True + except Exception: + logger.info("Skipping helper GGUF because export status is unavailable") + return True + + return False + + def _run_with_helper(prompt: str, max_tokens: int = 256) -> Optional[str]: """ Load helper model, run one chat completion, unload. @@ -118,13 +388,28 @@ def _run_with_helper(prompt: str, max_tokens: int = 256) -> Optional[str]: if os.environ.get("UNSLOTH_HELPER_MODEL_DISABLE", "").strip() in ("1", "true"): return None + # Round 23 P1 #3: round 22 only guarded against a busy + # diffusion pipeline. Training / export own the same GPU too, + # so use the broader helper that gates on all three workloads. + # Round 28 P1 #7 / #10: serialize the busy check + register pair + # so two concurrent helper invocations cannot both pass the + # precheck before either registers and then OOM each other. 
repo = os.environ.get("UNSLOTH_HELPER_MODEL_REPO", DEFAULT_HELPER_MODEL_REPO) variant = os.environ.get( "UNSLOTH_HELPER_MODEL_VARIANT", DEFAULT_HELPER_MODEL_VARIANT ) - + with _HELPER_ADVISOR_START_LOCK: + if _gpu_workload_busy_for_helper(): + return None + _register_helper_advisor_repo(repo) backend = None try: + # Round 26 P1 #1 / #3 / #13 / #14: use a PRIVATE backend so the + # helper can never preempt or be preempted by the user's + # chat backend and cannot accidentally unload it in finally. + # The active repo is published via _register_helper_advisor_repo + # above so DELETE /api/models/delete-cached can still block the + # cache rmtree while the helper is downloading or mmap'ing. from core.inference.llama_cpp import LlamaCppBackend backend = LlamaCppBackend() @@ -176,6 +461,7 @@ def _run_with_helper(prompt: str, max_tokens: int = 256) -> Optional[str]: logger.info("Helper model unloaded") except Exception: pass + _unregister_helper_advisor_repo(repo) # ─── Public API ─────────────────────────────────────────────────────── @@ -508,13 +794,26 @@ def _run_multi_pass_advisor( if os.environ.get("UNSLOTH_HELPER_MODEL_DISABLE", "").strip() in ("1", "true"): return None + # Round 23 P1 #4: extend the round 22 diffusion-only check to + # training + export so the advisor cannot race the user's + # active workload for GPU memory. + # Round 28 P1 #8 / #10: serialize the precheck + register pair so + # two concurrent advisor invocations cannot both pass before + # either registers and then OOM each other. repo = os.environ.get("UNSLOTH_HELPER_MODEL_REPO", DEFAULT_HELPER_MODEL_REPO) variant = os.environ.get( "UNSLOTH_HELPER_MODEL_VARIANT", DEFAULT_HELPER_MODEL_VARIANT ) - + with _HELPER_ADVISOR_START_LOCK: + if _gpu_workload_busy_for_helper(): + return None + _register_helper_advisor_repo(repo) backend = None try: + # Round 26 P1 #2 / #4 / #13 / #14: mirror ``_run_with_helper`` + # and use a PRIVATE backend. 
Round 25's global-backend swap + # introduced chat-evict races and finally-eviction bugs. + # The registry above keeps delete-cache safe. from core.inference.llama_cpp import LlamaCppBackend backend = LlamaCppBackend() @@ -849,6 +1148,7 @@ def _run_multi_pass_advisor( logger.info("Advisor model unloaded") except Exception: pass + _unregister_helper_advisor_repo(repo) def llm_conversion_advisor( diff --git a/studio/frontend/src/app/router.tsx b/studio/frontend/src/app/router.tsx index c7bc0440bd..b50f3fe618 100644 --- a/studio/frontend/src/app/router.tsx +++ b/studio/frontend/src/app/router.tsx @@ -9,6 +9,7 @@ import { Route as dataRecipeRoute } from "./routes/data-recipes.$recipeId"; import { Route as chatRoute } from "./routes/chat"; import { Route as exportRoute } from "./routes/export"; import { Route as gridTestRoute } from "./routes/grid-test"; +import { Route as imagesRoute } from "./routes/images"; import { Route as indexRoute } from "./routes/index"; import { Route as loginRoute } from "./routes/login"; import { Route as onboardingRoute } from "./routes/onboarding"; @@ -26,6 +27,7 @@ const routeTree = rootRoute.addChildren([ studioRoute, chatRoute, exportRoute, + imagesRoute, dataRecipesRoute, dataRecipeRoute, ]); diff --git a/studio/frontend/src/app/routes/images.tsx b/studio/frontend/src/app/routes/images.tsx new file mode 100644 index 0000000000..1761612140 --- /dev/null +++ b/studio/frontend/src/app/routes/images.tsx @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 + +import { createRoute } from "@tanstack/react-router"; +import { lazy } from "react"; +import { requireAuth } from "../auth-guards"; +import { Route as rootRoute } from "./__root"; + +const ImagesPage = lazy(() => + import("@/features/images").then((m) => ({ + default: m.ImagesPage, + })), +); + +export const Route = createRoute({ + getParentRoute: () => rootRoute, + path: "/images", + staticData: { title: "Images" }, + beforeLoad: () => requireAuth(), + component: ImagesPage, +}); diff --git a/studio/frontend/src/components/app-sidebar.tsx b/studio/frontend/src/components/app-sidebar.tsx index aac5f8f8a8..9a0830db2d 100644 --- a/studio/frontend/src/components/app-sidebar.tsx +++ b/studio/frontend/src/components/app-sidebar.tsx @@ -50,6 +50,7 @@ import { Globe02Icon, HelpCircleIcon, Logout01Icon, + PaintBrush02Icon, Search01Icon, PowerIcon, PencilEdit02Icon, @@ -497,6 +498,18 @@ export function AppSidebar() { }} /> + { + if (chatOnly) return; + navigate({ to: "/images" }); + closeMobileIfOpen(); + }} + /> + (res: Response): Promise { + if (!res.ok) throw new Error(await readFastApiError(res)); + return (await res.json()) as T; +} + +export async function fetchDiffusionStatus(): Promise { + return parseJson( + await authFetch("/api/inference/images/status"), + ); +} + +export async function loadDiffusionModel( + payload: DiffusionLoadRequest, +): Promise { + return parseJson( + await authFetch("/api/inference/images/load", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }), + ); +} + +export async function unloadDiffusionModel(): Promise<{ is_loaded: boolean }> { + return parseJson<{ is_loaded: boolean }>( + await authFetch("/api/inference/images/unload", { method: "POST" }), + ); +} + +/** JSON.stringify cannot serialise BigInt directly. 
Pull the seed + * BigInt out, stringify the rest of the payload normally, then + * splice the seed's decimal digits back into the JSON literal at the + * exact ``"seed":`` slot. + * + * Avoids the previous regex-over-JSON approach, which could be + * tripped by a user-supplied prompt that exactly matched the + * sentinel string. With this approach the only thing we touch is + * the literal ``"seed":`` substring we wrote ourselves. + */ +function stringifyWithBigInt(value: DiffusionGenerateRequest): string { + const { seed, ...rest } = value; + if (typeof seed !== "bigint") { + return JSON.stringify(value); + } + // Serialise the rest without seed, then inject the seed at the end + // of the object literal as a JSON integer. Strip the trailing "}" + // and re-append once the field is added. + const base = JSON.stringify(rest); + const inner = base.length === 2 /* '{}' */ ? "" : base.slice(1, -1) + ","; + return `{${inner}"seed":${seed.toString()}}`; +} + +export async function generateDiffusionImage( + payload: DiffusionGenerateRequest, +): Promise { + return parseJson( + await authFetch("/api/inference/images/generate", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: stringifyWithBigInt(payload), + }), + ); +} diff --git a/studio/frontend/src/features/images/images-page.tsx b/studio/frontend/src/features/images/images-page.tsx new file mode 100644 index 0000000000..ff4a839cf3 --- /dev/null +++ b/studio/frontend/src/features/images/images-page.tsx @@ -0,0 +1,620 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { SectionCard } from "@/components/section-card"; +import { Slider } from "@/components/ui/slider"; +import { Spinner } from "@/components/ui/spinner"; +import { Textarea } from "@/components/ui/textarea"; +import { toast } from "@/lib/toast"; +import { PaintBrush02Icon, SparklesIcon, GpuIcon } from "@hugeicons/core-free-icons"; +import { HugeiconsIcon } from "@hugeicons/react"; +import { + fetchDiffusionStatus, + generateDiffusionImage, + loadDiffusionModel, + unloadDiffusionModel, + type DiffusionGenerateResponse, + type DiffusionStatus, +} from "./api"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; + +// Curated short list of working diffusion GGUFs. Picked to span +// size + license so any GPU class has at least one viable option: +// FLUX.2 klein 4B -> ~13 GB VRAM with Q4_K_S, Apache 2.0 +// FLUX.2 klein 9B -> ~17 GB VRAM, FLUX [klein] non-commercial (gated) +// FLUX.2 dev -> ~24+ GB VRAM, FLUX [dev] non-commercial (gated) +// FLUX.1 dev -> ~12 GB VRAM, older but widely tested (gated) +// +// Filenames mirror the Hub canonical case (lowercase 'flux-2-klein-4b') +// and base_repo is set explicitly so the backend never falls back to the +// family default. The CLI on the backend can load anything supported by +// detect_family(); this list just keeps the picker compact for the v1 UI. 
+const CURATED_MODELS: Array<{ + label: string; + repo_id: string; + default_gguf: string; + base_repo: string; + family: string; + notes: string; +}> = [ + { + label: "FLUX.2 klein base 4B (Q4_K_S, Apache 2.0)", + repo_id: "unsloth/FLUX.2-klein-base-4B-GGUF", + default_gguf: "flux-2-klein-base-4b-Q4_K_S.gguf", + base_repo: "black-forest-labs/FLUX.2-klein-base-4B", + family: "flux.2-klein", + notes: "13 GB VRAM, fastest. Apache 2.0, ungated.", + }, + { + label: "FLUX.2 klein 4B (Q4_K_S, distilled)", + repo_id: "unsloth/FLUX.2-klein-4B-GGUF", + default_gguf: "flux-2-klein-4b-Q4_K_S.gguf", + // Distilled GGUF must pair with the distilled base, not the Base + // checkpoint. The Hub model card for the GGUF lists + // base_model: black-forest-labs/FLUX.2-klein-4B. + base_repo: "black-forest-labs/FLUX.2-klein-4B", + family: "flux.2-klein", + notes: "13 GB VRAM. Distilled klein 4B. Requires HF access to FLUX.2 klein 4B.", + }, + { + label: "FLUX.2 klein 9B (Q4_K_S, gated)", + repo_id: "unsloth/FLUX.2-klein-9B-GGUF", + default_gguf: "flux-2-klein-9b-Q4_K_S.gguf", + base_repo: "black-forest-labs/FLUX.2-klein-9B", + family: "flux.2-klein", + notes: "17 GB VRAM. Higher quality distilled. Requires HF access to FLUX.2 klein 9B.", + }, + { + label: "FLUX.2 dev (Q4_K_S, gated)", + repo_id: "unsloth/FLUX.2-dev-GGUF", + default_gguf: "flux2-dev-Q4_K_S.gguf", + base_repo: "black-forest-labs/FLUX.2-dev", + family: "flux.2", + notes: "24+ GB VRAM. Requires HF access to FLUX.2 dev.", + }, + { + label: "FLUX.1 dev (Q4_K_S, city96, gated)", + repo_id: "city96/FLUX.1-dev-gguf", + default_gguf: "flux1-dev-Q4_K_S.gguf", + base_repo: "black-forest-labs/FLUX.1-dev", + family: "flux.1", + notes: "12 GB VRAM. Older but widely tested. 
Requires HF access to FLUX.1 dev.", + }, +]; + +const DEFAULT_PRESET = CURATED_MODELS[0]; + +const RESOLUTION_PRESETS: Array<{ label: string; w: number; h: number }> = [ + { label: "Square 1024", w: 1024, h: 1024 }, + { label: "Square 768", w: 768, h: 768 }, + { label: "Square 512", w: 512, h: 512 }, + { label: "Portrait 832x1216", w: 832, h: 1216 }, + { label: "Landscape 1216x832", w: 1216, h: 832 }, +]; + +export function ImagesPage() { + const [status, setStatus] = useState(null); + const [refreshingStatus, setRefreshingStatus] = useState(false); + const [busy, setBusy] = useState<"idle" | "loading" | "unloading" | "generating">("idle"); + + const [presetIndex, setPresetIndex] = useState(0); + const [customRepoId, setCustomRepoId] = useState(""); + const [customGguf, setCustomGguf] = useState(""); + const [customBaseRepo, setCustomBaseRepo] = useState(""); + const [customFamily, setCustomFamily] = useState("auto"); + const [useCustom, setUseCustom] = useState(false); + const [hfToken, setHfToken] = useState(""); + + const [prompt, setPrompt] = useState("a tiny ginger sloth coding in a sunlit treehouse, photorealistic"); + const [negativePrompt, setNegativePrompt] = useState(""); + const [steps, setSteps] = useState(24); + const [guidance, setGuidance] = useState(3.5); + const [resolutionIdx, setResolutionIdx] = useState(0); + const [seed, setSeed] = useState(""); + + const [results, setResults] = useState([]); + const lastErrorRef = useRef(null); + + const preset = CURATED_MODELS[presetIndex] ?? DEFAULT_PRESET; + const resolution = RESOLUTION_PRESETS[resolutionIdx]; + + // Round 30 P2 #12: split the fetch from the spinner toggle so the + // mount + auto-poll effects can call the fetch without the + // synchronous setRefreshingStatus(true) that tripped + // react-hooks/set-state-in-effect. 
+ const fetchAndUpdateStatus = useCallback(async () => { + try { + const next = await fetchDiffusionStatus(); + setStatus(next); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + if (lastErrorRef.current !== msg) { + lastErrorRef.current = msg; + toast.error("Could not fetch image-model status", { description: msg }); + } + } + }, []); + + const refreshStatus = useCallback(async () => { + setRefreshingStatus(true); + try { + await fetchAndUpdateStatus(); + } finally { + setRefreshingStatus(false); + } + }, [fetchAndUpdateStatus]); + + useEffect(() => { + // Defer the mount fetch out of the synchronous effect body so the + // setStatus call inside fetchAndUpdateStatus does not trip the + // react-hooks/set-state-in-effect rule. + const id = window.setTimeout(() => { + void fetchAndUpdateStatus(); + }, 0); + return () => window.clearTimeout(id); + }, [fetchAndUpdateStatus]); + + // Round 27 P2: when the backend is mid-load (is_loading=true) the + // status label froze at "Loading..." until the user clicked + // Refresh. Auto-poll every 2 s while a load is in flight so the + // UI tracks real backend progress. + useEffect(() => { + if (!status?.is_loading) return; + const id = window.setInterval(() => { + void fetchAndUpdateStatus(); + }, 2000); + return () => window.clearInterval(id); + }, [status?.is_loading, fetchAndUpdateStatus]); + + const handleLoad = useCallback(async () => { + setBusy("loading"); + try { + const repo = useCustom ? customRepoId.trim() : preset.repo_id; + const gguf = useCustom ? customGguf.trim() || undefined : preset.default_gguf; + // Custom mode lets the user pin a family explicitly because + // detect_family is substring-based and exotic repo names (custom + // fine-tunes, third-party mirrors) frequently fail to match. + // "auto" leaves the override blank and lets the backend infer. + const family = useCustom + ? customFamily === "auto" + ? 
undefined + : customFamily + : preset.family; + // Always pass base_repo for curated entries; custom-repo mode + // now also lets the user pin one because private / mirrored + // GGUFs (e.g. a 9B klein transformer) would otherwise fall + // back to the family-default 4B base and 500 on load. Empty + // string still falls back to the backend's smart-base / + // repo-id defaults. + const baseRepo = useCustom + ? customBaseRepo.trim() || undefined + : preset.base_repo; + if (!repo) { + toast.error("Pick a model first"); + return; + } + const next = await loadDiffusionModel({ + repo_id: repo, + gguf_filename: gguf, + base_repo: baseRepo, + family, + hf_token: hfToken.trim() || undefined, + }); + setStatus(next); + toast.success("Loaded image model", { description: next.repo_id ?? undefined }); + } catch (err) { + toast.error("Failed to load image model", { + description: err instanceof Error ? err.message : String(err), + }); + // Backend clears its old pipeline before allocating the new one; + // a failed swap leaves status.is_loaded=false while our local + // copy still says loaded. Re-fetch so Generate disables and the + // user does not see a stale "Loaded:" label. + await refreshStatus(); + } finally { + setBusy("idle"); + } + }, [useCustom, customRepoId, customGguf, customBaseRepo, customFamily, preset, hfToken, refreshStatus]); + + const handleUnload = useCallback(async () => { + setBusy("unloading"); + try { + await unloadDiffusionModel(); + await refreshStatus(); + } catch (err) { + toast.error("Failed to unload image model", { + description: err instanceof Error ? err.message : String(err), + }); + // Round 27 P2: a partial unload (subprocess refused to terminate, + // 503 from the backend) used to leave the UI showing the old + // "Loaded:" label even though the backend state was half torn + // down. Refresh so the button states match reality (mirrors + // handleLoad above which always re-fetches on catch). 
+ await refreshStatus(); + } finally { + setBusy("idle"); + } + }, [refreshStatus]); + + const handleGenerate = useCallback(async () => { + if (!prompt.trim()) { + toast.error("Prompt is empty"); + return; + } + setBusy("generating"); + try { + // Reject non-integer seeds and clamp to the [-2^63, 2^64 - 1] + // range the backend's torch.Generator can actually pack. JSON + // serialises BigInts as plain integers, so we keep the wire + // format compatible and avoid the Number(seed) precision loss + // (>= 2^53 silently rounds, producing a different image than + // the seed the user typed). When the seed fits a safe integer + // it goes through unchanged; larger seeds ride along as their + // BigInt-derived string via the wire-format BigInt JSON helper + // in the api layer. + const seedStr = seed.trim(); + let parsedSeed: number | bigint | undefined; + if (seedStr) { + if (!/^-?\d+$/.test(seedStr)) { + toast.error("Seed must be an integer"); + return; + } + let big: bigint; + try { + big = BigInt(seedStr); + } catch { + toast.error("Seed must be an integer"); + return; + } + const SEED_MIN = -(BigInt(2) ** BigInt(63)); + const SEED_MAX = BigInt(2) ** BigInt(64) - BigInt(1); + if (big < SEED_MIN || big > SEED_MAX) { + toast.error( + "Seed must be in [-2^63, 2^64 - 1] (the torch.Generator range)", + ); + return; + } + // Use a plain Number when it fits a safe integer so the + // existing api.ts JSON serialiser does not break on BigInt; + // otherwise pass the BigInt and let api.ts emit it as a JSON + // number via a custom replacer. + const SAFE_MAX = BigInt(Number.MAX_SAFE_INTEGER); + const SAFE_MIN = -SAFE_MAX; + parsedSeed = big >= SAFE_MIN && big <= SAFE_MAX ? 
Number(big) : big; + } + const out = await generateDiffusionImage({ + prompt, + negative_prompt: negativePrompt.trim() || undefined, + num_inference_steps: steps, + guidance_scale: guidance, + width: resolution.w, + height: resolution.h, + seed: parsedSeed, + }); + setResults((prev) => [out, ...prev].slice(0, 12)); + } catch (err) { + toast.error("Image generation failed", { + description: err instanceof Error ? err.message : String(err), + }); + } finally { + setBusy("idle"); + } + }, [prompt, negativePrompt, steps, guidance, resolution, seed]); + + const statusLabel = useMemo(() => { + if (!status) return refreshingStatus ? "Checking..." : "Not loaded"; + if (status.is_loading) return "Loading..."; + if (status.is_loaded) { + const dev = status.device ? ` on ${status.device}` : ""; + return `Loaded: ${status.repo_id ?? "(unknown)"} (${status.family ?? "unknown"})${dev}`; + } + return "Not loaded"; + }, [status, refreshingStatus]); + + // FLUX.2 / FLUX.2 klein pipelines do NOT accept negative_prompt and + // would 500 if we sent one through. The backend strips the field + // defensively but hiding it client-side keeps the UI honest. + // Round 29 P2 #12: also honour the user-picked customFamily when no + // model is loaded yet, so a Custom HF repo with family flux.2 / + // flux.2-klein hides the negative-prompt field correctly. + const supportsNegativePrompt = useMemo(() => { + const family = status?.family; + if (!family) { + let candidate: string | undefined; + if (useCustom) { + candidate = customFamily === "auto" ? undefined : customFamily; + } else { + candidate = preset.family; + } + if (!candidate) return true; + return !candidate.startsWith("flux.2"); + } + return !family.startsWith("flux.2"); + }, [status, useCustom, customFamily, preset.family]); + + return ( +
+ } + title="Local image generation" + description={ + "Run diffusion GGUFs from Hugging Face on your own GPU. " + + "Pick a curated FLUX.2 model or paste any unsloth/* GGUF repo." + } + > +
+
+ + + {!useCustom && ( +

{preset.notes}

+ )} +
+ + {useCustom && ( +
+ + setCustomRepoId(e.target.value)} + placeholder="unsloth/FLUX.2-klein-4B-GGUF" + /> + + setCustomGguf(e.target.value)} + placeholder="FLUX.2-klein-4B-Q4_K_S.gguf" + /> + + setCustomBaseRepo(e.target.value)} + placeholder="black-forest-labs/FLUX.2-klein-9B" + /> +

+ {"Optional. Defaults to the family base. Set this when "} + {"your GGUF expects a non-default base (for example a 9B "} + {"transformer that would otherwise fall back to a 4B base)."} +

+ + +

+ {"Set this when your repo name does not contain "} + {"a recognised family substring (e.g. private fine-tunes)."} +

+
+ )} + +
+ + setHfToken(e.target.value)} + placeholder="hf_..." + autoComplete="off" + /> +
+ +
+ + + + + {statusLabel} + +
+
+
+ + } + title="Prompt" + description="The pipeline runs on the GPU you launched Unsloth Studio on." + > +
+
+ +