diff --git a/studio/backend/routes/data_recipe/jobs.py b/studio/backend/routes/data_recipe/jobs.py index 1d5eceee03..00546b47a4 100644 --- a/studio/backend/routes/data_recipe/jobs.py +++ b/studio/backend/routes/data_recipe/jobs.py @@ -5,7 +5,9 @@ from __future__ import annotations +from datetime import timedelta from typing import Any +from urllib.parse import urlparse from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import JSONResponse, StreamingResponse @@ -26,6 +28,161 @@ router = APIRouter() +def _resolve_local_v1_endpoint(request: Request) -> str: + """Return the loopback /v1 URL for the actual backend listen port. + + Resolution order: + 1. ``app.state.server_port`` - explicitly published by run.py after + the uvicorn server has bound. This is the most reliable source + because it survives reverse proxies, TLS terminators and tunnels. + 2. ``request.scope["server"]`` - the real (host, port) tuple uvicorn + sets when the request is dispatched. Used when Studio is started + outside ``run_server`` (e.g. ``uvicorn studio.backend.main:app``). + 3. ``request.base_url`` parsed - last resort for test fixtures that + do not route through a live uvicorn server. + """ + port: Any = getattr(request.app.state, "server_port", None) + if not isinstance(port, int) or port <= 0: + server = request.scope.get("server") + if ( + isinstance(server, tuple) + and len(server) >= 2 + and isinstance(server[1], int) + and server[1] > 0 + ): + port = server[1] + else: + parsed = urlparse(str(request.base_url)) + port = parsed.port if parsed.port is not None else 8888 + return f"http://127.0.0.1:{int(port)}/v1" + + +def _used_llm_model_aliases(recipe: dict[str, Any]) -> set[str]: + """Return the set of model_aliases that are actually referenced by an + LLM column. Used to narrow the "Chat model loaded" gate so that orphan + model_config nodes on the canvas do not block unrelated recipe runs. + + The ``llm-`` prefix matches the existing convention in + ``core/data_recipe/service.py::_recipe_has_llm_columns`` and covers all + LLM column types emitted by the frontend (llm-text, llm-code, + llm-structured, llm-judge). + """ + aliases: set[str] = set() + for column in recipe.get("columns", []): + if not isinstance(column, dict): + continue + column_type = column.get("column_type") + if not isinstance(column_type, str) or not column_type.startswith("llm-"): + continue + alias = column.get("model_alias") + if isinstance(alias, str) and alias: + aliases.add(alias) + return aliases + + +def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None: + """ + Mutate recipe dict in-place: for any provider with is_local=True, + generate a JWT and fill in the endpoint pointing at this server. + """ + providers = recipe.get("model_providers") + if not providers: + return + + # Collect local providers and pop is_local from ALL dicts unconditionally. + # Strict `is True` guard so malformed payloads (is_local: 1, + # is_local: "true") do not accidentally trigger the loopback rewrite. + local_indices: list[int] = [] + for i, provider in enumerate(providers): + if not isinstance(provider, dict): + continue + is_local = provider.pop("is_local", None) + if is_local is True: + local_indices.append(i) + + if not local_indices: + return + + endpoint = _resolve_local_v1_endpoint(request) + + # Only gate on model-loaded if a local provider is actually reachable + # from an LLM column through a model_config. Orphan model_config nodes + # that reference a local provider but that no LLM column uses should + # not block runs; the recipe would never call /v1 for them. + local_names = { + providers[i].get("name") for i in local_indices if providers[i].get("name") + } + used_aliases = _used_llm_model_aliases(recipe) + referenced_providers = { + mc.get("provider") + for mc in recipe.get("model_configs", []) + if ( + isinstance(mc, dict) + and mc.get("provider") + and mc.get("alias") in used_aliases + ) + } + + token = "" + if local_names & referenced_providers: + # Verify a model is loaded. + # NOTE: This is a point-in-time check (TOCTOU). The model could be unloaded + # or swapped after this check but before the recipe subprocess calls /v1. + # The inference endpoint returns a clear 400 in that case. + # + # Imports are deferred to avoid circular dependencies with inference modules. + from routes.inference import get_llama_cpp_backend + from core.inference import get_inference_backend + + llama = get_llama_cpp_backend() + model_loaded = llama.is_loaded + if not model_loaded: + backend = get_inference_backend() + model_loaded = bool(backend.active_model_name) + if not model_loaded: + raise ValueError( + "No model loaded in Chat. Load a model first, then run the recipe." + ) + + from auth.authentication import ( + create_access_token, + ) # deferred: avoids circular import + + # Uses the "unsloth" admin subject. If the user changes their password, + # the JWT secret rotates and this token becomes invalid mid-run. + # Acceptable for v1 - recipes typically finish well within one session. + token = create_access_token( + subject = "unsloth", + expires_delta = timedelta(hours = 24), + ) + + # Defensively strip any stale "external"-only fields the frontend may + # have left on the dict (extra_headers/extra_body/api_key_env). The UI + # hides these inputs in local mode but the payload builder still serializes + # them, so a previously external provider that flipped to local can carry + # invalid JSON or rogue auth headers into the local /v1 call. + for i in local_indices: + providers[i]["endpoint"] = endpoint + providers[i]["api_key"] = token + providers[i]["provider_type"] = "openai" + providers[i].pop("api_key_env", None) + providers[i].pop("extra_headers", None) + providers[i].pop("extra_body", None) + + # Force skip_health_check on any model_config that references a local + # provider. The local /v1/models endpoint only lists the real loaded + # model (e.g. "unsloth/llama-3.2-1b") and not the placeholder "local" + # that the recipe sends as the model id, so data_designer's pre-flight + # health check would otherwise fail before the first completion call. + # The backend route ignores the model id field in chat completions, so + # skipping the check is safe. + for mc in recipe.get("model_configs", []): + if not isinstance(mc, dict): + continue + if mc.get("provider") in local_names: + mc["skip_health_check"] = True + + def _normalize_run_name(value: Any) -> str | None: if value is None: return None @@ -40,7 +197,7 @@ def _normalize_run_name(value: Any) -> str | None: @router.post("/jobs", response_class = JSONResponse, response_model = JobCreateResponse) -def create_job(payload: RecipePayload): +def create_job(payload: RecipePayload, request: Request): recipe = payload.recipe if not recipe.get("columns"): raise HTTPException(status_code = 400, detail = "Recipe must include columns.") @@ -67,6 +224,11 @@ def create_job(payload: RecipePayload): status_code = 400, detail = f"invalid run_config: {exc}" ) from exc + try: + _inject_local_providers(recipe, request) + except ValueError as exc: + raise HTTPException(status_code = 400, detail = str(exc)) from exc + mgr = get_job_manager() try: job_id = mgr.start(recipe = recipe, run = run) diff --git a/studio/backend/routes/data_recipe/validate.py b/studio/backend/routes/data_recipe/validate.py index a793a3b172..555e3eaa06 100644 --- a/studio/backend/routes/data_recipe/validate.py +++ b/studio/backend/routes/data_recipe/validate.py @@ -68,6 +68,20 @@ def _collect_validation_errors(recipe: dict[str, Any]) -> list[ValidateError]: return errors +def _patch_local_providers(recipe: dict[str, Any]) -> None: + """Strip is_local and fill a dummy endpoint so validation doesn't choke. + + Uses a strict `is True` check to match _inject_local_providers in + jobs.py - malformed payloads with truthy but non-boolean is_local + values should not be treated as local. + """ + for provider in recipe.get("model_providers", []): + if not isinstance(provider, dict): + continue + if provider.pop("is_local", None) is True: + provider["endpoint"] = "http://127.0.0.1" + + @router.post("/validate", response_model = ValidateResponse) def validate(payload: RecipePayload) -> ValidateResponse: recipe = payload.recipe @@ -77,6 +91,8 @@ def validate(payload: RecipePayload) -> ValidateResponse: errors = [ValidateError(message = "Recipe must include columns.")], ) + _patch_local_providers(recipe) + try: validate_recipe(recipe) except RuntimeError as exc: diff --git a/studio/backend/run.py b/studio/backend/run.py index 9c3622988e..86c1194661 100644 --- a/studio/backend/run.py +++ b/studio/backend/run.py @@ -324,6 +324,14 @@ def run_server( _server = uvicorn.Server(config) _shutdown_event = Event() + # Expose the actual bound port so request-handling code can build + # loopback URLs that point at the real backend, not whatever port a + # reverse proxy or tunnel exposed in the request URL. Only publish + # an explicit value when we know the concrete port; for ephemeral + # binds (port==0) leave it unset and let request handlers fall back + # to the ASGI request scope or request.base_url. + app.state.server_port = port if port and port > 0 else None + # Run server in a daemon thread def _run(): asyncio.run(_server.serve()) diff --git a/studio/frontend/src/features/recipe-studio/blocks/render-dialog.tsx b/studio/frontend/src/features/recipe-studio/blocks/render-dialog.tsx index 10fcaa489d..92f72dfff1 100644 --- a/studio/frontend/src/features/recipe-studio/blocks/render-dialog.tsx +++ b/studio/frontend/src/features/recipe-studio/blocks/render-dialog.tsx @@ -28,6 +28,7 @@ export function renderBlockDialog( categoryOptions: SamplerConfig[], modelConfigAliases: string[], modelProviderOptions: string[], + localProviderNames: Set, toolProfileAliases: string[], datetimeOptions: string[], onUpdate: (id: string, patch: Partial) => void, @@ -109,6 +110,7 @@ export function renderBlockDialog( ) : null; diff --git a/studio/frontend/src/features/recipe-studio/components/inline/inline-model.tsx b/studio/frontend/src/features/recipe-studio/components/inline/inline-model.tsx index d3ca22b00b..16e99f4fae 100644 --- a/studio/frontend/src/features/recipe-studio/components/inline/inline-model.tsx +++ b/studio/frontend/src/features/recipe-studio/components/inline/inline-model.tsx @@ -10,11 +10,21 @@ type InlineModelPatch = Partial | Partial; type InlineModelProps = { config: ModelProviderConfig | ModelConfig; + localProviderNames?: Set; onUpdate: (patch: InlineModelPatch) => void; }; export function InlineModel(props: InlineModelProps): ReactElement { if (props.config.kind === "model_provider") { + if (props.config.is_local) { + return ( +
+ + Local model (Chat) + +
+ ); + } return (
@@ -42,21 +52,40 @@ export function InlineModel(props: InlineModelProps): ReactElement { ); } + // model_config branch - mirror the local-aware provider sync from the + // dialog path so inline edits do not leave stale "local" placeholders + // on external providers and fill the placeholder when switching to local. + const localNames = props.localProviderNames ?? new Set(); + const modelConfig = props.config; + const handleProviderChange = (nextProvider: string) => { + const isLocal = localNames.has(nextProvider); + if (isLocal && !modelConfig.model.trim()) { + props.onUpdate({ provider: nextProvider, model: "local" }); + return; + } + if (!isLocal && modelConfig.model === "local") { + props.onUpdate({ provider: nextProvider, model: "" }); + return; + } + props.onUpdate({ provider: nextProvider }); + }; + const isLinkedToLocal = localNames.has(modelConfig.provider); + return (
props.onUpdate({ provider: event.target.value })} + value={modelConfig.provider} + onChange={(event) => handleProviderChange(event.target.value)} /> props.onUpdate({ model: event.target.value })} /> @@ -65,7 +94,7 @@ export function InlineModel(props: InlineModelProps): ReactElement { className="nodrag h-8 w-full text-xs" type="number" placeholder="0.7" - value={props.config.inference_temperature ?? ""} + value={modelConfig.inference_temperature ?? ""} onChange={(event) => props.onUpdate({ // biome-ignore lint/style/useNamingConvention: api schema diff --git a/studio/frontend/src/features/recipe-studio/components/recipe-graph-node.tsx b/studio/frontend/src/features/recipe-studio/components/recipe-graph-node.tsx index 0fcf202190..8afc4e26bf 100644 --- a/studio/frontend/src/features/recipe-studio/components/recipe-graph-node.tsx +++ b/studio/frontend/src/features/recipe-studio/components/recipe-graph-node.tsx @@ -30,7 +30,7 @@ import { Position, useUpdateNodeInternals, } from "@xyflow/react"; -import { type ReactElement, memo, useEffect } from "react"; +import { type ReactElement, memo, useEffect, useMemo } from "react"; import { MAX_NODE_WIDTH, MAX_NOTE_NODE_WIDTH, @@ -287,6 +287,7 @@ function renderNodeBody( config: NodeConfig | undefined, summary: string, updateConfig: (id: string, patch: Partial) => void, + localProviderNames: Set, ): ReactElement { if (config?.kind === "markdown_note") { return ; @@ -300,7 +301,13 @@ function renderNodeBody( return ; } if (config.kind === "model_provider" || config.kind === "model_config") { - return ; + return ( + + ); } if (config.kind === "llm") { return ; @@ -355,6 +362,16 @@ function RecipeGraphNodeBase({ const config = useRecipeStudioStore((state) => state.configs[id]); const openConfig = useRecipeStudioStore((state) => state.openConfig); const updateConfig = useRecipeStudioStore((state) => state.updateConfig); + const allConfigs = useRecipeStudioStore((state) => state.configs); + const localProviderNames = useMemo(() => { + const names = new Set(); + for (const cfg of Object.values(allConfigs)) { + if (cfg.kind === "model_provider" && cfg.is_local === true) { + names.add(cfg.name); + } + } + return names; + }, [allConfigs]); const llmAuxVisible = useRecipeStudioStore( (state) => state.llmAuxVisibility[id] ?? false, ); @@ -418,7 +435,12 @@ function RecipeGraphNodeBase({ data.kind === "tool_config" || data.kind === "validator"; const summary = getConfigSummary(config); - const nodeBody = renderNodeBody(config, summary, updateConfig); + const nodeBody = renderNodeBody( + config, + summary, + updateConfig, + localProviderNames, + ); const canShowLlmAux = config?.kind === "llm" && (Boolean(config.prompt.trim()) || diff --git a/studio/frontend/src/features/recipe-studio/dialogs/config-dialog.tsx b/studio/frontend/src/features/recipe-studio/dialogs/config-dialog.tsx index 6457993c35..c62169956d 100644 --- a/studio/frontend/src/features/recipe-studio/dialogs/config-dialog.tsx +++ b/studio/frontend/src/features/recipe-studio/dialogs/config-dialog.tsx @@ -18,6 +18,7 @@ type ConfigDialogProps = { categoryOptions: SamplerConfig[]; modelConfigAliases: string[]; modelProviderOptions: string[]; + localProviderNames: Set; toolProfileAliases: string[]; datetimeOptions: string[]; onUpdate: (id: string, patch: Partial) => void; @@ -32,6 +33,7 @@ export function ConfigDialog({ categoryOptions, modelConfigAliases, modelProviderOptions, + localProviderNames, toolProfileAliases, datetimeOptions, onUpdate, @@ -101,6 +103,7 @@ export function ConfigDialog({ categoryOptions, modelConfigAliases, modelProviderOptions, + localProviderNames, toolProfileAliases, datetimeOptions, onUpdate, diff --git a/studio/frontend/src/features/recipe-studio/dialogs/models/model-config-dialog.tsx b/studio/frontend/src/features/recipe-studio/dialogs/models/model-config-dialog.tsx index 3192dc8fd9..368ae08acb 100644 --- a/studio/frontend/src/features/recipe-studio/dialogs/models/model-config-dialog.tsx +++ b/studio/frontend/src/features/recipe-studio/dialogs/models/model-config-dialog.tsx @@ -17,7 +17,7 @@ import { } from "@/components/ui/combobox"; import { Input } from "@/components/ui/input"; import { Textarea } from "@/components/ui/textarea"; -import { type ReactElement, useRef, useState } from "react"; +import { type ReactElement, useEffect, useRef, useState } from "react"; import type { ModelConfig } from "../../types"; import { CollapsibleSectionTriggerButton } from "../shared/collapsible-section-trigger"; import { FieldLabel } from "../shared/field-label"; @@ -26,14 +26,17 @@ import { NameField } from "../shared/name-field"; type ModelConfigDialogProps = { config: ModelConfig; providerOptions: string[]; + localProviderNames: Set; onUpdate: (patch: Partial) => void; }; export function ModelConfigDialog({ config, providerOptions, + localProviderNames, onUpdate, }: ModelConfigDialogProps): ReactElement { + const isLinkedToLocal = localProviderNames.has(config.provider); const [optionalOpen, setOptionalOpen] = useState(false); const modelId = `${config.id}-model`; const providerId = `${config.id}-provider`; @@ -44,11 +47,13 @@ export function ModelConfigDialog({ const extraBodyId = `${config.id}-inference-extra-body`; const providerAnchorRef = useRef(null); const providerInputRef = useRef(config.provider); - const lastProviderRef = useRef(config.provider); - if (lastProviderRef.current !== config.provider) { - lastProviderRef.current = config.provider; + // Sync providerInputRef with the current provider value. Updating a ref in + // an effect (vs reading/writing it during render) satisfies the + // react-hooks/refs rule and keeps the combobox blur path stable across + // re-renders. + useEffect(() => { providerInputRef.current = config.provider; - } + }, [config.provider]); const updateField = ( key: K, value: ModelConfig[K], @@ -56,6 +61,21 @@ export function ModelConfigDialog({ onUpdate({ [key]: value } as Partial); }; + // Apply provider selection while keeping the local-provider model autofill + // consistent across both dropdown selection and free-typed + blur input. + const applyProviderChange = (selectedProvider: string) => { + const isLocal = localProviderNames.has(selectedProvider); + if (isLocal && !config.model.trim()) { + onUpdate({ provider: selectedProvider, model: "local" }); + return; + } + if (!isLocal && config.model === "local") { + onUpdate({ provider: selectedProvider, model: "" }); + return; + } + updateField("provider", selectedProvider); + }; + return (
updateField("provider", value ?? "")} + onValueChange={(value) => applyProviderChange(value ?? "")} onInputValueChange={(value) => { providerInputRef.current = value; }} @@ -98,7 +118,7 @@ export function ModelConfigDialog({ onBlur={() => { const next = providerInputRef.current; if (next !== config.provider) { - updateField("provider", next); + applyProviderChange(next); } }} /> @@ -124,12 +144,12 @@ export function ModelConfigDialog({ updateField("model", event.target.value)} /> diff --git a/studio/frontend/src/features/recipe-studio/dialogs/models/model-provider-dialog.tsx b/studio/frontend/src/features/recipe-studio/dialogs/models/model-provider-dialog.tsx index 897fc7ee36..ef7366a225 100644 --- a/studio/frontend/src/features/recipe-studio/dialogs/models/model-provider-dialog.tsx +++ b/studio/frontend/src/features/recipe-studio/dialogs/models/model-provider-dialog.tsx @@ -24,6 +24,7 @@ export function ModelProviderDialog({ onUpdate, }: ModelProviderDialogProps): ReactElement { const [optionalOpen, setOptionalOpen] = useState(false); + const isLocal = config.is_local ?? false; const endpointId = `${config.id}-endpoint`; const apiKeyEnvId = `${config.id}-api-key-env`; const apiKeyId = `${config.id}-api-key`; @@ -43,94 +44,163 @@ export function ModelProviderDialog({ value={config.name} onChange={(value) => onUpdate({ name: value })} /> -
-

- Start with the endpoint you want this model to use -

-

- Most connections only need an endpoint. Add an API key if that - service requires one. -

-
-
- - updateField("endpoint", event.target.value)} - /> -
+ + {/* Model source toggle */}
- - updateField("api_key", event.target.value)} - /> +

Model source

+
+ + +
- - - - - + + {isLocal ? ( +
+

+ Ready to go +

+

+ Recipes will use whatever model is loaded in the Chat tab when you + hit run. No endpoint or API key needed. +

+
+ ) : ( + <> +
+

+ Start with the endpoint you want this model to use +

+

+ Most connections only need an endpoint. Add an API key if that + service requires one. +

+
updateField("api_key_env", event.target.value)} + placeholder="https://..." + value={config.endpoint} + onChange={(event) => updateField("endpoint", event.target.value)} />
-