diff --git a/assistant/src/__tests__/credential-security-invariants.test.ts b/assistant/src/__tests__/credential-security-invariants.test.ts index 0b4ad8e6c3c..f473297749d 100644 --- a/assistant/src/__tests__/credential-security-invariants.test.ts +++ b/assistant/src/__tests__/credential-security-invariants.test.ts @@ -208,6 +208,7 @@ describe("Invariant 2: no generic plaintext secret read API", () => { "config/bundled-skills/image-studio/tools/media-generate-image.ts", // image generation tool API key lookup "config/bundled-skills/media-processing/tools/analyze-keyframes.ts", // keyframe analysis tool API key lookup "providers/registry.ts", // provider registry API key lookup for initialization + "providers/inference/resolve-auth.ts", // provider_connection auth resolver (api_key path reads vault, mirrors registry.ts) "providers/provider-availability.ts", // provider availability API key check "media/image-credentials.ts", // shared image-gen credential resolver (provider API key lookup) "memory/embedding-backend.ts", // embedding backend API key lookup diff --git a/assistant/src/cli/commands/__tests__/inference-send.test.ts b/assistant/src/cli/commands/__tests__/inference-send.test.ts index ddbd99c98d6..49e1aac981e 100644 --- a/assistant/src/cli/commands/__tests__/inference-send.test.ts +++ b/assistant/src/cli/commands/__tests__/inference-send.test.ts @@ -111,11 +111,13 @@ mock.module("../../../providers/provider-send-message.js", () => ({ })); mock.module("../../../config/loader.js", () => ({ - getConfigReadOnly: () => ({ - llm: { - profiles: mockProfileCatalog, - }, - }), + getConfig: () => ({ llm: { profiles: mockProfileCatalog } }), + getConfigReadOnly: () => ({ llm: { profiles: mockProfileCatalog } }), + loadConfig: () => ({ llm: { profiles: mockProfileCatalog } }), + loadRawConfig: () => ({}) as Record, + saveRawConfig: () => {}, + invalidateConfigCache: () => {}, + applyNestedDefaults: () => ({ llm: { profiles: mockProfileCatalog } }), })); 
mock.module("../../../util/logger.js", () => ({ diff --git a/assistant/src/cli/commands/inference-providers.ts b/assistant/src/cli/commands/inference-providers.ts new file mode 100644 index 00000000000..fca6aa95c2e --- /dev/null +++ b/assistant/src/cli/commands/inference-providers.ts @@ -0,0 +1,443 @@ +/** + * `assistant inference providers` CLI namespace. + * + * Provider-scoped admin commands. Currently exposes one subcommand: + * + * `assistant inference providers connections ` + * list — list all connections (optionally filtered by provider) + * get — show a single connection + * create — create a new connection + * update — update a connection's auth + * delete — delete a connection (rejects if profiles reference it) + * + * Future provider-scoped commands (capabilities, picklists, etc.) hang off + * the `providers` namespace alongside `connections`. + */ + +import type { Command } from "commander"; + +import { getConfigReadOnly } from "../../config/loader.js"; +import { getDb } from "../../memory/db-connection.js"; +import { AuthSchema, VALID_CONNECTION_PROVIDERS } from "../../providers/inference/auth.js"; +import { + createConnection, + deleteConnection, + getConnection, + listConnections, + updateConnection, +} from "../../providers/inference/connections.js"; +import { log } from "../logger.js"; + +// --------------------------------------------------------------------------- +// Formatting helpers +// --------------------------------------------------------------------------- + +function formatAuth(auth: ReturnType): string { + switch (auth.type) { + case "api_key": + return `api_key (credential: ${auth.credential})`; + case "platform": + return "platform (managed proxy)"; + case "none": + return "none"; + case "oauth_subscription": + return `oauth_subscription (credential: ${auth.credential})`; + case "service_account": + return `service_account (credential: ${auth.credential})`; + } +} + +// 
--------------------------------------------------------------------------- +// Subcommand: list +// --------------------------------------------------------------------------- + +function attachListSubcommand(connections: Command): void { + connections + .command("list") + .description("List all provider connections") + .option("--provider

", "Filter by provider") + .option("--json", "Output as JSON") + .action(async (opts: { provider?: string; json?: boolean }) => { + const db = getDb(); + const rows = listConnections(db, opts.provider ? { provider: opts.provider } : undefined); + + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: true, connections: rows }) + "\n"); + return; + } + + if (rows.length === 0) { + process.stdout.write("No connections found.\n"); + return; + } + + for (const conn of rows) { + process.stdout.write( + `${conn.name} provider=${conn.provider} auth=${formatAuth(conn.auth)}\n`, + ); + } + }); +} + +// --------------------------------------------------------------------------- +// Subcommand: get +// --------------------------------------------------------------------------- + +function attachGetSubcommand(connections: Command): void { + connections + .command("get ") + .description("Show a single provider connection") + .option("--json", "Output as JSON") + .action(async (name: string, opts: { json?: boolean }) => { + const db = getDb(); + const conn = getConnection(db, name); + + if (!conn) { + const msg = `Connection "${name}" not found.`; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: true, connection: conn }) + "\n"); + return; + } + + process.stdout.write(`name: ${conn.name}\n`); + process.stdout.write(`provider: ${conn.provider}\n`); + process.stdout.write(`auth: ${formatAuth(conn.auth)}\n`); + process.stdout.write( + `created: ${new Date(conn.createdAt).toISOString()}\n`, + ); + process.stdout.write( + `updated: ${new Date(conn.updatedAt).toISOString()}\n`, + ); + }); +} + +// --------------------------------------------------------------------------- +// Subcommand: create +// --------------------------------------------------------------------------- + +function 
attachCreateSubcommand(connections: Command): void { + connections + .command("create ") + .description("Create a new provider connection") + .requiredOption("--provider

", `Provider (${VALID_CONNECTION_PROVIDERS.join("|")})`) + .requiredOption("--auth ", "Auth type: api_key|platform|none") + .option("--credential ", "Vault credential name (required for --auth api_key)") + .option("--json", "Output as JSON") + .action( + async ( + name: string, + opts: { provider: string; auth: string; credential?: string; json?: boolean }, + ) => { + let authInput: unknown; + if (opts.auth === "api_key") { + if (!opts.credential) { + const msg = "--credential is required when --auth api_key"; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + authInput = { type: "api_key", credential: opts.credential }; + } else if (opts.auth === "platform") { + if (opts.credential) { + const msg = "--credential is not accepted with --auth platform"; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + authInput = { type: "platform" }; + } else if (opts.auth === "none") { + if (opts.credential) { + const msg = "--credential is not accepted with --auth none"; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + authInput = { type: "none" }; + } else { + const msg = `Unknown auth type "${opts.auth}". 
Use: api_key, platform, none`; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + const authResult = AuthSchema.safeParse(authInput); + if (!authResult.success) { + const msg = `Invalid auth: ${authResult.error.message}`; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + const db = getDb(); + const result = createConnection(db, { + name, + provider: opts.provider, + auth: authResult.data, + }); + + if (!result.ok) { + let msg: string; + if (result.error.code === "already_exists") { + msg = `Connection "${name}" already exists. Use 'update' to modify it.`; + } else if (result.error.code === "invalid_provider") { + msg = `Invalid provider "${result.error.provider}". Valid: ${VALID_CONNECTION_PROVIDERS.join(", ")}`; + } else { + msg = "Invalid auth configuration."; + } + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + if (opts.json) { + process.stdout.write( + JSON.stringify({ ok: true, connection: result.connection }) + "\n", + ); + } else { + process.stdout.write( + `Created connection "${result.connection.name}" (provider=${result.connection.provider}, auth=${formatAuth(result.connection.auth)})\n`, + ); + } + }, + ); +} + +// --------------------------------------------------------------------------- +// Subcommand: update +// --------------------------------------------------------------------------- + +function attachUpdateSubcommand(connections: Command): void { + connections + .command("update ") + .description("Update a connection's auth") + .requiredOption("--auth ", "Auth type: api_key|platform|none") + .option("--credential ", "Vault credential name (required for --auth api_key)") + .option("--json", "Output 
as JSON") + .action( + async ( + name: string, + opts: { auth: string; credential?: string; json?: boolean }, + ) => { + let authInput: unknown; + if (opts.auth === "api_key") { + if (!opts.credential) { + const msg = "--credential is required when --auth api_key"; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + authInput = { type: "api_key", credential: opts.credential }; + } else if (opts.auth === "platform") { + if (opts.credential) { + const msg = "--credential is not accepted with --auth platform"; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + authInput = { type: "platform" }; + } else if (opts.auth === "none") { + if (opts.credential) { + const msg = "--credential is not accepted with --auth none"; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + authInput = { type: "none" }; + } else { + const msg = `Unknown auth type "${opts.auth}". Use: api_key, platform, none`; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + const authResult = AuthSchema.safeParse(authInput); + if (!authResult.success) { + const msg = `Invalid auth: ${authResult.error.message}`; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + const db = getDb(); + const result = updateConnection(db, name, { auth: authResult.data }); + + if (!result.ok) { + const msg = + result.error.code === "not_found" + ? 
`Connection "${name}" not found.` + : "Invalid auth configuration."; + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + if (opts.json) { + process.stdout.write( + JSON.stringify({ ok: true, connection: result.connection }) + "\n", + ); + } else { + process.stdout.write( + `Updated connection "${name}" auth to ${formatAuth(result.connection.auth)}\n`, + ); + } + }, + ); +} + +// --------------------------------------------------------------------------- +// Subcommand: delete +// --------------------------------------------------------------------------- + +function attachDeleteSubcommand(connections: Command): void { + connections + .command("delete ") + .description("Delete a provider connection") + .option("--force", "Delete even if profiles reference this connection") + .option("--json", "Output as JSON") + .action(async (name: string, opts: { force?: boolean; json?: boolean }) => { + const db = getDb(); + + // Find profiles referencing this connection. + const config = getConfigReadOnly(); + const profiles = config.llm?.profiles ?? {}; + const referencingProfiles = Object.entries(profiles) + .filter(([, p]) => (p as Record).provider_connection === name) + .map(([profileName]) => profileName); + + const result = deleteConnection(db, name, { + force: opts.force, + referencingProfiles, + }); + + if (!result.ok) { + let msg: string; + if (result.error.code === "not_found") { + msg = `Connection "${name}" not found.`; + } else if (result.error.code === "has_references") { + msg = + `Connection "${name}" is referenced by ${result.error.count} profile(s): ` + + `${referencingProfiles.join(", ")}. 
` + + "Use --force to delete anyway (profiles will error at next inference call)."; + } else { + msg = "Delete failed."; + } + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: false, error: msg }) + "\n"); + } else { + log.error(msg); + } + process.exitCode = 1; + return; + } + + if (opts.json) { + process.stdout.write(JSON.stringify({ ok: true }) + "\n"); + } else { + process.stdout.write(`Deleted connection "${name}"\n`); + if (referencingProfiles.length > 0 && opts.force) { + process.stdout.write( + `Warning: ${referencingProfiles.length} profile(s) now reference a missing connection: ` + + `${referencingProfiles.join(", ")}\n`, + ); + } + } + }); +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +export function attachProvidersSubcommand(inference: Command): void { + const providers = inference + .command("providers") + .description("Inference provider admin commands"); + + const connections = providers + .command("connections") + .description("Manage provider connections (auth configs for inference)"); + + connections.addHelpText( + "after", + ` +Provider connections map a name to a (provider, auth) pair. +Profiles reference connections via the 'provider_connection' field. 
+ +Canonical connections (seeded on every boot): + anthropic-managed → provider=anthropic, auth=platform + openai-managed → provider=openai, auth=platform + gemini-managed → provider=gemini, auth=platform + ollama-local → provider=ollama, auth=none + +Examples: + $ assistant inference providers connections list + $ assistant inference providers connections get anthropic-managed + $ assistant inference providers connections create anthropic-personal \\ + --provider anthropic --auth api_key --credential credential/anthropic/api_key + $ assistant inference providers connections update anthropic-personal --auth platform + $ assistant inference providers connections delete anthropic-personal`, + ); + + attachListSubcommand(connections); + attachGetSubcommand(connections); + attachCreateSubcommand(connections); + attachUpdateSubcommand(connections); + attachDeleteSubcommand(connections); +} diff --git a/assistant/src/cli/commands/inference.ts b/assistant/src/cli/commands/inference.ts index 92d863562ad..52692f700f9 100644 --- a/assistant/src/cli/commands/inference.ts +++ b/assistant/src/cli/commands/inference.ts @@ -9,6 +9,7 @@ import { userMessage, } from "../../providers/provider-send-message.js"; import { log } from "../logger.js"; +import { attachProvidersSubcommand } from "./inference-providers.js"; import { attachSessionSubcommand } from "./inference-session.js"; /** @@ -218,6 +219,7 @@ Examples: attachSendSubcommand(inference); attachSessionSubcommand(inference); + attachProvidersSubcommand(inference); const llm = program .command("llm") diff --git a/assistant/src/config/schemas/llm.ts b/assistant/src/config/schemas/llm.ts index b642505660c..349890014de 100644 --- a/assistant/src/config/schemas/llm.ts +++ b/assistant/src/config/schemas/llm.ts @@ -324,6 +324,14 @@ export const ProfileEntry = LLMConfigFragment.extend({ source: ProfileSource.optional(), label: z.string().min(1).optional(), description: z.string().optional(), + /** + * Name of a `provider_connections` 
row to use for this profile. + * When set, the dispatcher resolves auth from the connection instead of + * the global `services.inference.mode` toggle. Additive alongside the + * legacy `provider` + `source` fields; those remain as read-only + * deprecated fallbacks for profiles not yet backfilled. + */ + provider_connection: z.string().min(1).optional(), }); export type ProfileEntry = z.infer; diff --git a/assistant/src/daemon/lifecycle.ts b/assistant/src/daemon/lifecycle.ts index 56ae3a22839..d97c0b74a75 100644 --- a/assistant/src/daemon/lifecycle.ts +++ b/assistant/src/daemon/lifecycle.ts @@ -44,6 +44,7 @@ import { } from "../memory/attachments-store.js"; import { expireAllPendingCanonicalRequests } from "../memory/canonical-guardian-store.js"; import { deleteMessageById, getMessages } from "../memory/conversation-crud.js"; +import { getDb } from "../memory/db-connection.js"; import { initializeDb } from "../memory/db-init.js"; import { selectEmbeddingBackend } from "../memory/embedding-backend.js"; import { enqueueMemoryJob } from "../memory/jobs-store.js"; @@ -60,6 +61,7 @@ import { seedOAuthProviders } from "../oauth/seed-providers.js"; import { loadUserPlugins } from "../plugins/user-loader.js"; import { backfillGuardIfNeeded } from "../proactive-artifact/index.js"; import { ensurePromptFiles } from "../prompts/system-prompt.js"; +import { runProviderConnectionsBackfill } from "../providers/inference/backfill.js"; import { resolveManagedProxyContext } from "../providers/managed-proxy/context.js"; import { broadcastMessage } from "../runtime/assistant-event-hub.js"; import { @@ -364,6 +366,20 @@ export async function runDaemon(): Promise { } } + // Seed canonical inference provider_connections and backfill any legacy + // profiles that pre-date the connection field. Idempotent — runs every + // boot so new canonicals propagate and manual config.json edits self-heal. 
+ if (dbReady) { + try { + runProviderConnectionsBackfill(getDb()); + } catch (err) { + log.warn( + { err }, + "provider_connections backfill failed — continuing startup", + ); + } + } + if (dbReady) { await runWorkspaceMigrations(getWorkspaceDir(), WORKSPACE_MIGRATIONS); log.info("Daemon startup: workspace migrations complete"); diff --git a/assistant/src/memory/db-init.ts b/assistant/src/memory/db-init.ts index 0386847d7a0..9c3bddd570b 100644 --- a/assistant/src/memory/db-init.ts +++ b/assistant/src/memory/db-init.ts @@ -73,6 +73,7 @@ import { migrateCreateMemoryGraphNodeEdits, migrateCreateMemoryGraphTables, migrateCreateMemoryRecallLogs, + migrateCreateProviderConnections, migrateCreateThreadStartersTable, migrateCreateTraceEventsTable, migrateDeletePrivateConversations, @@ -416,6 +417,7 @@ export function initializeDb(): void { migrateTraceEventsCreatedAtIndex, migrateConversationInferenceProfileSession, migrateMessageBookmarks, + migrateCreateProviderConnections, ]; // Run each migration step, catching and logging individual failures so one diff --git a/assistant/src/memory/migrations/243-provider-connections.ts b/assistant/src/memory/migrations/243-provider-connections.ts new file mode 100644 index 00000000000..c9a4a7efd6d --- /dev/null +++ b/assistant/src/memory/migrations/243-provider-connections.ts @@ -0,0 +1,70 @@ +import { type DrizzleDb, getSqliteFrom } from "../db-connection.js"; + +/** + * Creates the `provider_connections` table and seeds the four canonical + * connections that every installation ships with. + * + * Canonical connections: + * - anthropic-managed → provider=anthropic, auth={type:platform} + * - openai-managed → provider=openai, auth={type:platform} + * - gemini-managed → provider=gemini, auth={type:platform} + * - ollama-local → provider=ollama, auth={type:none} + * + * Idempotent: checks sqlite_master for the table before running DDL; + * canonical rows are inserted with INSERT OR IGNORE. 
+ */ +export function migrateCreateProviderConnections(database: DrizzleDb): void { + const raw = getSqliteFrom(database); + + const tableExists = raw + .query( + `SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'provider_connections'`, + ) + .get(); + + if (!tableExists) { + try { + raw.exec("BEGIN"); + + raw.exec(/*sql*/ ` + CREATE TABLE provider_connections ( + name TEXT PRIMARY KEY, + provider TEXT NOT NULL, + auth TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) + `); + + raw.exec(/*sql*/ ` + CREATE INDEX idx_provider_connections_provider + ON provider_connections(provider) + `); + + raw.exec("COMMIT"); + } catch (e) { + try { + raw.exec("ROLLBACK"); + } catch { + /* no active transaction */ + } + throw e; + } + } + + // Seed canonical connections — idempotent via INSERT OR IGNORE. + const now = Date.now(); + const canonicals = [ + { name: "anthropic-managed", provider: "anthropic", auth: JSON.stringify({ type: "platform" }) }, + { name: "openai-managed", provider: "openai", auth: JSON.stringify({ type: "platform" }) }, + { name: "gemini-managed", provider: "gemini", auth: JSON.stringify({ type: "platform" }) }, + { name: "ollama-local", provider: "ollama", auth: JSON.stringify({ type: "none" }) }, + ]; + + for (const { name, provider, auth } of canonicals) { + raw.run( + `INSERT OR IGNORE INTO provider_connections (name, provider, auth, created_at, updated_at) VALUES (?, ?, ?, ?, ?)`, + [name, provider, auth, now, now], + ); + } +} diff --git a/assistant/src/memory/migrations/index.ts b/assistant/src/memory/migrations/index.ts index 9d32c9c9911..1395bc2a1ed 100644 --- a/assistant/src/memory/migrations/index.ts +++ b/assistant/src/memory/migrations/index.ts @@ -204,6 +204,7 @@ export { migrateTraceEventsCreatedAtIndex } from "./239-trace-events-created-at- export { migrateConversationInferenceProfileSession } from "./240-conversation-inference-profile-session.js"; export { migrateActivationStateFkCascade } from 
"./241-activation-state-fk-cascade.js"; export { migrateMessageBookmarks } from "./242-message-bookmarks.js"; +export { migrateCreateProviderConnections } from "./243-provider-connections.js"; export { MIGRATION_REGISTRY, type MigrationRegistryEntry, diff --git a/assistant/src/memory/schema/index.ts b/assistant/src/memory/schema/index.ts index 7302a99cc64..89bc61ffcc8 100644 --- a/assistant/src/memory/schema/index.ts +++ b/assistant/src/memory/schema/index.ts @@ -4,6 +4,7 @@ export * from "./calls.js"; export * from "./contacts.js"; export * from "./conversations.js"; export * from "./guardian.js"; +export * from "./inference.js"; export * from "./infrastructure.js"; export * from "./memory-core.js"; export * from "./memory-graph.js"; diff --git a/assistant/src/memory/schema/inference.ts b/assistant/src/memory/schema/inference.ts new file mode 100644 index 00000000000..b5d85c60967 --- /dev/null +++ b/assistant/src/memory/schema/inference.ts @@ -0,0 +1,27 @@ +import { index, integer, sqliteTable, text } from "drizzle-orm/sqlite-core"; + +/** + * Named provider connections. + * + * Each row is a named auth-config instance for a code-defined provider. + * Profiles in config.json reference connections by `name` via the + * `provider_connection` field. + * + * Created by migration 243. 
+ */ +export const providerConnections = sqliteTable( + "provider_connections", + { + name: text("name").primaryKey(), + provider: text("provider").notNull(), + auth: text("auth").notNull(), + createdAt: integer("created_at").notNull(), + updatedAt: integer("updated_at").notNull(), + }, + (table) => [ + index("idx_provider_connections_provider").on(table.provider), + ], +); + +export type ProviderConnectionRow = typeof providerConnections.$inferSelect; +export type NewProviderConnectionRow = typeof providerConnections.$inferInsert; diff --git a/assistant/src/memory/v2/__tests__/backfill-jobs.test.ts b/assistant/src/memory/v2/__tests__/backfill-jobs.test.ts index 6e73cdf222e..2d13505ce63 100644 --- a/assistant/src/memory/v2/__tests__/backfill-jobs.test.ts +++ b/assistant/src/memory/v2/__tests__/backfill-jobs.test.ts @@ -115,7 +115,10 @@ const STUB_RUNTIME_CONFIG = { }; mock.module("../../../config/loader.js", () => ({ getConfig: () => STUB_RUNTIME_CONFIG, + getConfigReadOnly: () => STUB_RUNTIME_CONFIG, loadConfig: () => STUB_RUNTIME_CONFIG, + loadRawConfig: () => ({}) as Record, + saveRawConfig: () => {}, invalidateConfigCache: () => {}, applyNestedDefaults: () => STUB_RUNTIME_CONFIG, })); diff --git a/assistant/src/providers/__tests__/inference.test.ts b/assistant/src/providers/__tests__/inference.test.ts new file mode 100644 index 00000000000..9db513015ee --- /dev/null +++ b/assistant/src/providers/__tests__/inference.test.ts @@ -0,0 +1,287 @@ +/** + * Tests for provider_connections: migration, CRUD, and + * mix-and-match E2E (two profiles, same provider, different connections). 
+ */ + +import { Database } from "bun:sqlite"; +import { describe, expect, test } from "bun:test"; + +import { drizzle } from "drizzle-orm/bun-sqlite"; + +import type { DrizzleDb } from "../../memory/db-connection.js"; +import { getSqliteFrom } from "../../memory/db-connection.js"; +import { migrateCreateProviderConnections } from "../../memory/migrations/243-provider-connections.js"; +import * as schema from "../../memory/schema.js"; +import { AuthSchema } from "../inference/auth.js"; +import { + createConnection, + deleteConnection, + getConnection, + listConnections, + seedCanonicalConnections, + updateConnection, +} from "../inference/connections.js"; + +// --------------------------------------------------------------------------- +// Setup — each test gets a fresh in-memory DB +// --------------------------------------------------------------------------- + +function setupDb(): { db: DrizzleDb; raw: Database } { + const sqlite = new Database(":memory:"); + sqlite.exec("PRAGMA journal_mode=WAL"); + sqlite.exec("PRAGMA foreign_keys = ON"); + const db = drizzle(sqlite, { schema }); + const raw = getSqliteFrom(db); + migrateCreateProviderConnections(db); + return { db, raw }; +} + +// --------------------------------------------------------------------------- +// Migration idempotency +// --------------------------------------------------------------------------- + +describe("migrateCreateProviderConnections", () => { + test("creates the provider_connections table", () => { + const { raw } = setupDb(); + const rows = raw.query("SELECT name FROM provider_connections").all() as { name: string }[]; + expect(Array.isArray(rows)).toBe(true); + }); + + test("seeds canonical connections on first run", () => { + const { db } = setupDb(); + const canonicals = ["anthropic-managed", "openai-managed", "gemini-managed", "ollama-local"]; + for (const name of canonicals) { + const conn = getConnection(db, name); + expect(conn).not.toBeNull(); + } + }); + + test("canonical 
connections have correct auth types", () => { + const { db } = setupDb(); + expect(getConnection(db, "anthropic-managed")?.auth.type).toBe("platform"); + expect(getConnection(db, "openai-managed")?.auth.type).toBe("platform"); + expect(getConnection(db, "gemini-managed")?.auth.type).toBe("platform"); + expect(getConnection(db, "ollama-local")?.auth.type).toBe("none"); + }); + + test("seedCanonicalConnections is idempotent", () => { + const { db } = setupDb(); + // Run twice — should not throw or create duplicates + seedCanonicalConnections(db); + seedCanonicalConnections(db); + const managed = listConnections(db, { provider: "anthropic" }); + expect(managed.filter((c) => c.name === "anthropic-managed").length).toBe(1); + }); +}); + +// --------------------------------------------------------------------------- +// Connection CRUD +// --------------------------------------------------------------------------- + +describe("Connection CRUD", () => { + test("createConnection — happy path", () => { + const { db } = setupDb(); + const result = createConnection(db, { + name: "my-anthropic", + provider: "anthropic", + auth: { type: "api_key", credential: "credential/anthropic/api_key" }, + }); + expect(result.ok).toBe(true); + if (!result.ok) return; + expect(result.connection.name).toBe("my-anthropic"); + expect(result.connection.provider).toBe("anthropic"); + expect(result.connection.auth.type).toBe("api_key"); + }); + + test("createConnection — rejects unknown provider", () => { + const { db } = setupDb(); + const result = createConnection(db, { + name: "bad-conn", + provider: "unknown-llm" as never, + auth: { type: "none" }, + }); + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.error.code).toBe("invalid_provider"); + }); + + test("createConnection — rejects duplicate name", () => { + const { db } = setupDb(); + createConnection(db, { + name: "dup-conn", + provider: "openai", + auth: { type: "platform" }, + }); + const result = 
createConnection(db, { + name: "dup-conn", + provider: "openai", + auth: { type: "platform" }, + }); + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.error.code).toBe("already_exists"); + }); + + test("getConnection — returns null for unknown name", () => { + const { db } = setupDb(); + expect(getConnection(db, "nonexistent")).toBeNull(); + }); + + test("listConnections — filters by provider", () => { + const { db } = setupDb(); + createConnection(db, { + name: "test-openai", + provider: "openai", + auth: { type: "api_key", credential: "credential/openai/api_key" }, + }); + const openai = listConnections(db, { provider: "openai" }); + expect(openai.every((c) => c.provider === "openai")).toBe(true); + }); + + test("updateConnection — happy path", () => { + const { db } = setupDb(); + createConnection(db, { + name: "updatable", + provider: "anthropic", + auth: { type: "platform" }, + }); + const result = updateConnection(db, "updatable", { + auth: { type: "api_key", credential: "credential/anthropic/api_key" }, + }); + expect(result.ok).toBe(true); + if (!result.ok) return; + expect(result.connection.auth.type).toBe("api_key"); + const fetched = getConnection(db, "updatable"); + expect(fetched?.auth.type).toBe("api_key"); + }); + + test("updateConnection — rejects unknown name", () => { + const { db } = setupDb(); + const result = updateConnection(db, "ghost", { auth: { type: "none" } }); + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.error.code).toBe("not_found"); + }); + + test("deleteConnection — happy path", () => { + const { db } = setupDb(); + createConnection(db, { + name: "to-delete", + provider: "gemini", + auth: { type: "platform" }, + }); + const result = deleteConnection(db, "to-delete"); + expect(result.ok).toBe(true); + expect(getConnection(db, "to-delete")).toBeNull(); + }); + + test("deleteConnection — rejects unknown name", () => { + const { db } = setupDb(); + const result = deleteConnection(db, 
"ghost"); + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.error.code).toBe("not_found"); + }); + + test("deleteConnection — rejects when profiles reference it (no --force)", () => { + const { db } = setupDb(); + createConnection(db, { + name: "referenced", + provider: "anthropic", + auth: { type: "platform" }, + }); + const result = deleteConnection(db, "referenced", { + force: false, + referencingProfiles: ["profile-a", "profile-b"], + }); + expect(result.ok).toBe(false); + if (result.ok) return; + expect(result.error.code).toBe("has_references"); + if (result.error.code !== "has_references") return; + expect(result.error.count).toBe(2); + }); + + test("deleteConnection --force removes even with references", () => { + const { db } = setupDb(); + createConnection(db, { + name: "force-delete", + provider: "anthropic", + auth: { type: "platform" }, + }); + const result = deleteConnection(db, "force-delete", { + force: true, + referencingProfiles: ["some-profile"], + }); + expect(result.ok).toBe(true); + expect(getConnection(db, "force-delete")).toBeNull(); + }); +}); + +// --------------------------------------------------------------------------- +// Auth schema validation +// --------------------------------------------------------------------------- + +describe("AuthSchema", () => { + test("api_key variant requires credential", () => { + const ok = AuthSchema.safeParse({ type: "api_key", credential: "cred/foo/api_key" }); + expect(ok.success).toBe(true); + + const bad = AuthSchema.safeParse({ type: "api_key" }); // missing credential + expect(bad.success).toBe(false); + }); + + test("platform variant has no extra fields", () => { + const ok = AuthSchema.safeParse({ type: "platform" }); + expect(ok.success).toBe(true); + }); + + test("none variant parses", () => { + const ok = AuthSchema.safeParse({ type: "none" }); + expect(ok.success).toBe(true); + }); + + test("oauth_subscription and service_account parse (v2 variants, 
runtime-rejected)", () => { + expect( + AuthSchema.safeParse({ type: "oauth_subscription", credential: "x" }).success, + ).toBe(true); + expect( + AuthSchema.safeParse({ type: "service_account", credential: "x" }).success, + ).toBe(true); + }); +}); + +// --------------------------------------------------------------------------- +// Mix-and-match correctness +// --------------------------------------------------------------------------- + +describe("Mix-and-match: two profiles, same provider, different connections", () => { + test("getConnection returns the right auth for each connection name", () => { + const { db } = setupDb(); + + // anthropic-managed already exists (canonical seed) with platform auth. + const managedConn = getConnection(db, "anthropic-managed"); + expect(managedConn?.auth.type).toBe("platform"); + + // Create a personal connection with api_key auth. + createConnection(db, { + name: "anthropic-personal", + provider: "anthropic", + auth: { type: "api_key", credential: "credential/anthropic/api_key" }, + }); + + const personalConn = getConnection(db, "anthropic-personal"); + expect(personalConn?.auth.type).toBe("api_key"); + + // Both connections exist for the same provider. + const anthropicConns = listConnections(db, { provider: "anthropic" }); + const names = anthropicConns.map((c) => c.name); + expect(names).toContain("anthropic-managed"); + expect(names).toContain("anthropic-personal"); + + // Auth is distinct per connection. 
+ const managed = anthropicConns.find((c) => c.name === "anthropic-managed"); + const personal = anthropicConns.find((c) => c.name === "anthropic-personal"); + expect(managed?.auth.type).toBe("platform"); + expect(personal?.auth.type).toBe("api_key"); + }); +}); diff --git a/assistant/src/providers/inference/adapter-factory.ts b/assistant/src/providers/inference/adapter-factory.ts new file mode 100644 index 00000000000..ad0f3235cce --- /dev/null +++ b/assistant/src/providers/inference/adapter-factory.ts @@ -0,0 +1,142 @@ +/** + * Creates provider adapter instances from a resolved auth + connection. + * + * Adapters are created per-call when dispatching through a named + * `provider_connection`, enabling mix-and-match auth (e.g. managed and + * your-own Anthropic connections coexisting in the same registry). + */ + +import { AnthropicProvider } from "../anthropic/client.js"; +import { FireworksProvider } from "../fireworks/client.js"; +import { GeminiProvider } from "../gemini/client.js"; +import { OllamaProvider } from "../ollama/client.js"; +import { OpenAIResponsesProvider } from "../openai/responses-provider.js"; +import { OpenRouterProvider } from "../openrouter/client.js"; +import { RetryProvider } from "../retry.js"; +import type { Provider } from "../types.js"; +import { UsageTrackingProvider } from "../usage-tracking.js"; +import type { ResolvedAuth } from "./auth.js"; +import type { ProviderConnection } from "./auth.js"; + +export interface AdapterOptions { + model: string; + streamTimeoutMs?: number; + useNativeWebSearch?: boolean; +} + +/** + * Build a Provider instance for a given connection + resolved auth. + * + * Returns null for an unknown provider or `none` auth on a keyed provider; the + * caller decides whether to log a warning or fall back to the global registry. + * NOTE(review): `runtime_proxy` auth is not rejected — keyed providers get an empty API key. 
+ */ +export function createAdapterFromConnection( + connection: ProviderConnection, + resolvedAuth: ResolvedAuth, + opts: AdapterOptions, +): Provider | null { + const { provider } = connection; + const { model, streamTimeoutMs = 1_800_000, useNativeWebSearch = false } = opts; + + let adapter: Provider | null = null; + + switch (provider) { + case "anthropic": { + if (resolvedAuth.kind === "none") return null; + const apiKey = + resolvedAuth.kind === "header" + ? (resolvedAuth.headers["Authorization"] ?? "").replace(/^Bearer /, "") + : ""; + adapter = new AnthropicProvider(apiKey, model, { + useNativeWebSearch, + streamTimeoutMs, + ...(resolvedAuth.kind === "header" && resolvedAuth.baseUrl + ? { baseURL: resolvedAuth.baseUrl } + : {}), + }); + break; + } + + case "openai": { + if (resolvedAuth.kind === "none") return null; + const apiKey = + resolvedAuth.kind === "header" + ? (resolvedAuth.headers["Authorization"] ?? "").replace(/^Bearer /, "") + : ""; + adapter = new OpenAIResponsesProvider(apiKey, model, { + useNativeWebSearch, + streamTimeoutMs, + ...(resolvedAuth.kind === "header" && resolvedAuth.baseUrl + ? { baseURL: resolvedAuth.baseUrl } + : {}), + }); + break; + } + + case "gemini": { + if (resolvedAuth.kind === "none") return null; + const apiKey = + resolvedAuth.kind === "header" + ? (resolvedAuth.headers["Authorization"] ?? "").replace(/^Bearer /, "") + : ""; + adapter = new GeminiProvider(apiKey, model, { + streamTimeoutMs, + ...(resolvedAuth.kind === "header" && resolvedAuth.baseUrl + ? { managedBaseUrl: resolvedAuth.baseUrl } + : {}), + }); + break; + } + + case "ollama": { + // Ollama supports no-auth operation; header auth is also accepted (API key param). + const apiKey = + resolvedAuth.kind === "header" + ? (resolvedAuth.headers["Authorization"] ?? "").replace(/^Bearer /, "") + : undefined; + adapter = new OllamaProvider(model, { + apiKey: apiKey ?? 
undefined, + streamTimeoutMs, + }); + break; + } + + case "fireworks": { + if (resolvedAuth.kind === "none") return null; + const apiKey = + resolvedAuth.kind === "header" + ? (resolvedAuth.headers["Authorization"] ?? "").replace(/^Bearer /, "") + : ""; + adapter = new FireworksProvider(apiKey, model, { streamTimeoutMs }); + break; + } + + case "openrouter": { + if (resolvedAuth.kind === "none") return null; + const apiKey = + resolvedAuth.kind === "header" + ? (resolvedAuth.headers["Authorization"] ?? "").replace(/^Bearer /, "") + : ""; + adapter = new OpenRouterProvider(apiKey, model, { + useNativeWebSearch, + streamTimeoutMs, + }); + break; + } + + default: + return null; + } + + if (!adapter) return null; + + const isProxy = + resolvedAuth.kind === "header" && resolvedAuth.baseUrl !== undefined; + + return new UsageTrackingProvider( + new RetryProvider(adapter, { + forwardUsageAttributionHeaders: isProxy, + }), + ); +} diff --git a/assistant/src/providers/inference/auth.ts b/assistant/src/providers/inference/auth.ts new file mode 100644 index 00000000000..f29a8363394 --- /dev/null +++ b/assistant/src/providers/inference/auth.ts @@ -0,0 +1,85 @@ +import { z } from "zod"; + +// --------------------------------------------------------------------------- +// Auth discriminated union (stored in provider_connections.auth as JSON) +// --------------------------------------------------------------------------- + +/** + * Auth configuration stored in the `provider_connections` table. + * + * v1 runtime-supported variants: + * - api_key: look up `credential` in vault, inject as bearer/provider header. + * - platform: route via Vellum managed proxy; no client-side credential. + * - none: no auth (e.g. Ollama running locally). + * + * v2 schema-accepted variants (runtime rejects with a clear "not yet shipped" error): + * - oauth_subscription: OAuth-based subscription auth. + * - service_account: service-account credentials (Vertex AI, Bedrock). 
+ */ +export const AuthSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.literal("api_key"), + credential: z.string().min(1), + }), + z.object({ + type: z.literal("platform"), + }), + z.object({ + type: z.literal("none"), + }), + z.object({ + type: z.literal("oauth_subscription"), + credential: z.string().min(1), + }), + z.object({ + type: z.literal("service_account"), + credential: z.string().min(1), + }), +]); + +export type Auth = z.infer; + +// --------------------------------------------------------------------------- +// ResolvedAuth — what the dispatcher hands to each adapter +// --------------------------------------------------------------------------- + +/** + * The resolved form of an Auth, produced by the dispatcher before calling + * an adapter. Adapters are pure functions of (ResolvedAuth, request) → response + * and never access the vault themselves. + */ +export type ResolvedAuth = + | { kind: "header"; headers: Record; baseUrl?: string } + | { kind: "runtime_proxy"; route: string } + | { kind: "none" }; + +// --------------------------------------------------------------------------- +// Valid provider identifiers (code-defined closed set) +// --------------------------------------------------------------------------- + +export const VALID_CONNECTION_PROVIDERS = [ + "anthropic", + "openai", + "gemini", + "ollama", + "fireworks", + "openrouter", +] as const; + +export type ConnectionProvider = typeof VALID_CONNECTION_PROVIDERS[number]; + +export const ConnectionProviderSchema = z.enum(VALID_CONNECTION_PROVIDERS); + +// --------------------------------------------------------------------------- +// Full connection shape used by CRUD layer +// --------------------------------------------------------------------------- + +export const ProviderConnectionSchema = z.object({ + name: z.string().min(1), + provider: ConnectionProviderSchema, + auth: AuthSchema, + createdAt: z.number().int(), + updatedAt: z.number().int(), +}); + +export type 
ProviderConnection = z.infer; diff --git a/assistant/src/providers/inference/backfill.ts b/assistant/src/providers/inference/backfill.ts new file mode 100644 index 00000000000..c6e958af978 --- /dev/null +++ b/assistant/src/providers/inference/backfill.ts @@ -0,0 +1,130 @@ +/** + * Boot-time backfill: migrates existing config.json profiles from the legacy + * `provider` + `source` model to the new `provider_connection` model. + * + * Idempotent: profiles that already have `provider_connection` are skipped. + * Only modifies config.json when at least one profile needs updating. + */ + +import { loadRawConfig, saveRawConfig } from "../../config/loader.js"; +import type { DrizzleDb } from "../../memory/db-connection.js"; +import { credentialKey } from "../../security/credential-key.js"; +import { getLogger } from "../../util/logger.js"; +import { createConnection, getConnection, seedCanonicalConnections } from "./connections.js"; + +const log = getLogger("provider-connections-backfill"); + +// Providers that support the managed (platform) auth type. +const MANAGED_PROVIDERS = new Set(["anthropic", "openai", "gemini"]); + +/** + * Seed canonical provider_connections and backfill any legacy profiles that + * pre-date the connection field. + * + * Runs on every daemon boot — both halves are idempotent and cheap (O(profiles), + * typically ≤10). Designed to: + * - propagate new canonical connections as they're added in future versions + * - self-heal manual config.json edits that drop the connection field + * + * Steps: + * 1. Seed canonical connections (a row is inserted only when its name is missing). + * 2. Walk `llm.profiles.*` in config.json. + * 3. For each profile without `provider_connection`, derive one from + * `services.inference.mode` + the profile's `provider` (not `source`) and write it back. + * 4. Save config.json if any profiles were updated. 
+ */ +export function runProviderConnectionsBackfill(db: DrizzleDb): void { + try { + seedCanonicalConnections(db); + backfillConfigProfiles(db); + } catch (err) { + log.error({ err }, "provider_connections backfill failed — will retry on next boot"); + } +} + +function backfillConfigProfiles(db: DrizzleDb): void { + const raw = loadRawConfig(); + const llm = raw.llm as Record | undefined; + if (!llm) return; + + const profiles = llm.profiles as Record | undefined; + if (!profiles || typeof profiles !== "object") return; + + // Route on the auth axis (`services.inference.mode`), not the ownership + // axis (`profile.source` is `managed`/`user`, system-vs-user-created). + // Conflating them would regress user-owned profiles in managed + // deployments to require local API keys. + // + // We must mirror `loadConfig()`'s deployment-context default here: on + // platform-managed daemons the file may omit `services.inference.mode` + // and rely on `IS_PLATFORM=true → managed` to be filled in by + // `getDeploymentContextDefaults()` at runtime. Reading raw config alone + // would default missing values to `"your-own"`, which would backfill + // every profile to a `*-personal` connection and bake incorrect auth + // routing into config.json for later connection-based dispatch. + const inferenceMode = (raw.services as Record | undefined) + ?.inference as Record | undefined; + const onDiskMode = inferenceMode?.mode as string | undefined; + const isPlatform = + process.env.IS_PLATFORM === "true" || process.env.IS_PLATFORM === "1"; + const globalMode = onDiskMode ?? (isPlatform ? "managed" : "your-own"); + + let changed = false; + + for (const [profileName, profileVal] of Object.entries(profiles)) { + const profile = profileVal as Record; + if (!profile || typeof profile !== "object") continue; + + // Skip profiles that already have a provider_connection. 
+ if (profile.provider_connection != null) continue; + + const provider = profile.provider as string | undefined; + if (!provider) continue; + + let connectionName: string; + + if (provider === "ollama") { + connectionName = "ollama-local"; + } else if (globalMode === "managed" && MANAGED_PROVIDERS.has(provider)) { + connectionName = `${provider}-managed`; + } else { + // "your-own" path (or provider not managed-supported): ensure a personal connection exists. + connectionName = `${provider}-personal`; + if (!getConnection(db, connectionName)) { + const credName = credentialKey(provider, "api_key"); + const result = createConnection(db, { + name: connectionName, + provider, + auth: { type: "api_key", credential: credName }, + }); + if (!result.ok) { + log.warn( + { profileName, provider, error: result.error }, + "Failed to create personal connection during backfill; skipping profile", + ); + continue; + } + log.info( + { connectionName, provider, credential: credName }, + "Created personal connection during backfill", + ); + } + } + + profile.provider_connection = connectionName; + profiles[profileName] = profile; + changed = true; + + log.info( + { profileName, connectionName }, + "Backfilled provider_connection for profile", + ); + } + + if (changed) { + llm.profiles = profiles; + raw.llm = llm; + saveRawConfig(raw); + log.info("Saved config.json after provider_connection backfill"); + } +} diff --git a/assistant/src/providers/inference/connections.ts b/assistant/src/providers/inference/connections.ts new file mode 100644 index 00000000000..8b8a86359b7 --- /dev/null +++ b/assistant/src/providers/inference/connections.ts @@ -0,0 +1,231 @@ +import { eq } from "drizzle-orm"; + +import type { DrizzleDb } from "../../memory/db-connection.js"; +import { providerConnections } from "../../memory/schema/inference.js"; +import { clearConnectionProviderCache } from "../registry.js"; +import { + type Auth, + AuthSchema, + type ConnectionProvider, + ConnectionProviderSchema, 
+ type ProviderConnection, + VALID_CONNECTION_PROVIDERS, +} from "./auth.js"; + +// --------------------------------------------------------------------------- +// Read +// --------------------------------------------------------------------------- + +export function listConnections( + db: DrizzleDb, + filter?: { provider?: string }, +): ProviderConnection[] { + const rows = filter?.provider + ? db.select().from(providerConnections).where(eq(providerConnections.provider, filter.provider)).all() + : db.select().from(providerConnections).all(); + + return rows.flatMap((row) => { + const auth = AuthSchema.safeParse(JSON.parse(row.auth)); + if (!auth.success) return []; + const provider = ConnectionProviderSchema.safeParse(row.provider); + if (!provider.success) return []; + return [{ ...row, auth: auth.data, provider: provider.data }]; + }); +} + +export function getConnection( + db: DrizzleDb, + name: string, +): ProviderConnection | null { + const row = db + .select() + .from(providerConnections) + .where(eq(providerConnections.name, name)) + .get(); + + if (!row) return null; + const auth = AuthSchema.safeParse(JSON.parse(row.auth)); + if (!auth.success) return null; + const provider = ConnectionProviderSchema.safeParse(row.provider); + if (!provider.success) return null; + return { ...row, auth: auth.data, provider: provider.data }; +} + +// --------------------------------------------------------------------------- +// Write +// --------------------------------------------------------------------------- + +export type CreateConnectionInput = { + name: string; + provider: string; + auth: Auth; +}; + +export type UpdateConnectionInput = { + auth: Auth; +}; + +export type ConnectionCreateError = + | { code: "already_exists" } + | { code: "invalid_provider"; provider: string } + | { code: "invalid_auth" }; + +export type ConnectionUpdateError = + | { code: "not_found" } + | { code: "invalid_auth" }; + +export type ConnectionDeleteError = + | { code: "not_found" } + | 
{ code: "has_references"; count: number }; + +export function createConnection( + db: DrizzleDb, + input: CreateConnectionInput, +): { ok: true; connection: ProviderConnection } | { ok: false; error: ConnectionCreateError } { + if (!VALID_CONNECTION_PROVIDERS.includes(input.provider as never)) { + return { ok: false, error: { code: "invalid_provider", provider: input.provider } }; + } + // Safe cast: VALID_CONNECTION_PROVIDERS.includes() guards above. + const provider = input.provider as ConnectionProvider; + + const authResult = AuthSchema.safeParse(input.auth); + if (!authResult.success) { + return { ok: false, error: { code: "invalid_auth" } }; + } + + const existing = db + .select({ name: providerConnections.name }) + .from(providerConnections) + .where(eq(providerConnections.name, input.name)) + .get(); + if (existing) { + return { ok: false, error: { code: "already_exists" } }; + } + + const now = Date.now(); + db.insert(providerConnections).values({ + name: input.name, + provider, + auth: JSON.stringify(authResult.data), + createdAt: now, + updatedAt: now, + }).run(); + + // Invalidate per-connection adapter cache so subsequent dispatch + // resolves the freshly-inserted row's auth. 
+ clearConnectionProviderCache(); + + return { + ok: true, + connection: { + name: input.name, + provider, + auth: authResult.data, + createdAt: now, + updatedAt: now, + }, + }; +} + +export function updateConnection( + db: DrizzleDb, + name: string, + input: UpdateConnectionInput, +): { ok: true; connection: ProviderConnection } | { ok: false; error: ConnectionUpdateError } { + const existing = getConnection(db, name); + if (!existing) { + return { ok: false, error: { code: "not_found" } }; + } + + const authResult = AuthSchema.safeParse(input.auth); + if (!authResult.success) { + return { ok: false, error: { code: "invalid_auth" } }; + } + + const now = Date.now(); + db.update(providerConnections) + .set({ auth: JSON.stringify(authResult.data), updatedAt: now }) + .where(eq(providerConnections.name, name)) + .run(); + + // Drop cached adapter built against the previous auth config. + clearConnectionProviderCache(); + + return { + ok: true, + connection: { ...existing, auth: authResult.data, updatedAt: now }, + }; +} + +/** + * Delete a connection. + * + * `force`: when true, delete even if profiles reference it. + * When false, rejects if any profile in the provided profile names list + * references this connection. 
+ */ +export function deleteConnection( + db: DrizzleDb, + name: string, + opts: { force?: boolean; referencingProfiles?: string[] } = {}, +): { ok: true } | { ok: false; error: ConnectionDeleteError } { + const existing = db + .select({ name: providerConnections.name }) + .from(providerConnections) + .where(eq(providerConnections.name, name)) + .get(); + + if (!existing) { + return { ok: false, error: { code: "not_found" } }; + } + + if (!opts.force && opts.referencingProfiles && opts.referencingProfiles.length > 0) { + return { + ok: false, + error: { code: "has_references", count: opts.referencingProfiles.length }, + }; + } + + db.delete(providerConnections).where(eq(providerConnections.name, name)).run(); + + // Evict cached adapter for the deleted connection name. + clearConnectionProviderCache(); + + return { ok: true }; +} + +// --------------------------------------------------------------------------- +// Seed canonical connections (idempotent, used at boot time) +// --------------------------------------------------------------------------- + +const CANONICAL_CONNECTIONS: Array<{ name: string; provider: string; auth: Auth }> = [ + { name: "anthropic-managed", provider: "anthropic", auth: { type: "platform" } }, + { name: "openai-managed", provider: "openai", auth: { type: "platform" } }, + { name: "gemini-managed", provider: "gemini", auth: { type: "platform" } }, + { name: "ollama-local", provider: "ollama", auth: { type: "none" } }, +]; + +/** + * Ensure the four canonical connections exist. Already-existing rows are left + * untouched. Safe to call on every boot. 
+ */ +export function seedCanonicalConnections(db: DrizzleDb): void { + const now = Date.now(); + for (const { name, provider, auth } of CANONICAL_CONNECTIONS) { + const exists = db + .select({ name: providerConnections.name }) + .from(providerConnections) + .where(eq(providerConnections.name, name)) + .get(); + + if (!exists) { + db.insert(providerConnections).values({ + name, + provider, + auth: JSON.stringify(auth), + createdAt: now, + updatedAt: now, + }).run(); + } + } +} diff --git a/assistant/src/providers/inference/resolve-auth.ts b/assistant/src/providers/inference/resolve-auth.ts new file mode 100644 index 00000000000..a3727e6460c --- /dev/null +++ b/assistant/src/providers/inference/resolve-auth.ts @@ -0,0 +1,65 @@ +/** + * Resolves an `Auth` config into a `ResolvedAuth` that adapters consume. + * + * Resolution rules: + * - api_key → fetch credential from vault → inject as bearer header + * - platform → build managed proxy URL and fetch the platform API key + * - none → pass through with no auth headers + * - oauth_subscription / service_account → reject (v2 not yet shipped) + */ + +import { + buildManagedBaseUrl, + resolveManagedProxyContext, +} from "../../providers/managed-proxy/context.js"; +import { getSecureKeyAsync } from "../../security/secure-keys.js"; +import type { Auth, ResolvedAuth } from "./auth.js"; + +export type ResolveAuthError = + | { code: "credential_not_found"; credential: string } + | { code: "platform_unavailable" } + | { code: "not_implemented"; authType: string }; + +export async function resolveAuth( + auth: Auth, + provider: string, +): Promise<{ ok: true; resolved: ResolvedAuth } | { ok: false; error: ResolveAuthError }> { + switch (auth.type) { + case "api_key": { + const value = await getSecureKeyAsync(auth.credential); + if (!value) { + return { ok: false, error: { code: "credential_not_found", credential: auth.credential } }; + } + return { + ok: true, + resolved: { kind: "header", headers: { Authorization: `Bearer 
${value}` } }, + }; + } + + case "platform": { + const managedBaseUrl = await buildManagedBaseUrl(provider); + if (!managedBaseUrl) { + return { ok: false, error: { code: "platform_unavailable" } }; + } + const ctx = await resolveManagedProxyContext(); + return { + ok: true, + resolved: { + kind: "header", + headers: { Authorization: `Bearer ${ctx.assistantApiKey}` }, + baseUrl: managedBaseUrl, + }, + }; + } + + case "none": + return { ok: true, resolved: { kind: "none" } }; + + case "oauth_subscription": + case "service_account": + return { + ok: false, + error: { code: "not_implemented", authType: auth.type }, + }; + } +} diff --git a/assistant/src/providers/registry.ts b/assistant/src/providers/registry.ts index a85a94a5160..0a4b5f8bf03 100644 --- a/assistant/src/providers/registry.ts +++ b/assistant/src/providers/registry.ts @@ -1,8 +1,15 @@ import { getProviderKeyAsync } from "../security/secure-keys.js"; import { ProviderNotConfiguredError } from "../util/errors.js"; +import { getLogger } from "../util/logger.js"; import { AnthropicProvider } from "./anthropic/client.js"; import { FireworksProvider } from "./fireworks/client.js"; import { GeminiProvider } from "./gemini/client.js"; +import { createAdapterFromConnection } from "./inference/adapter-factory.js"; +// --------------------------------------------------------------------------- +// Per-connection provider cache (mix-and-match support) +// --------------------------------------------------------------------------- +import type { ProviderConnection } from "./inference/auth.js"; +import { resolveAuth } from "./inference/resolve-auth.js"; import { buildManagedBaseUrl, resolveManagedProxyContext, @@ -16,9 +23,14 @@ import { RetryProvider } from "./retry.js"; import type { Provider } from "./types.js"; import { UsageTrackingProvider } from "./usage-tracking.js"; +const log = getLogger("provider-registry"); + const providers = new Map(); const routingSources = new Map(); +/** Per-connection provider cache, 
keyed by connection name. */ +const connectionProviders = new Map(); + function registerProvider(name: string, provider: Provider): void { providers.set(name, new UsageTrackingProvider(provider)); } @@ -69,12 +81,6 @@ function resolveModel(config: ProvidersConfig, providerName: string): string { const inferenceProvider = config.llm.default.provider; const inferenceModel = config.llm.default.model; if (inferenceProvider === providerName) { - // If a non-Anthropic provider is selected but the configured model is - // still an Anthropic catalog model (current or previous default), use a - // provider-appropriate fallback instead. Checking the full Anthropic - // catalog rather than only the current default prevents stale persisted - // defaults (e.g. claude-opus-4-6) from being sent to non-Anthropic APIs - // after the catalog default changes. if ( providerName !== "anthropic" && isModelInCatalog("anthropic", inferenceModel) @@ -100,7 +106,6 @@ async function resolveProviderCredentials( source: "user-key" | "managed-proxy"; } | null> { if (mode === "managed") { - // In managed mode, try managed proxy first, then fall back to user key const managedBaseUrl = await buildManagedBaseUrl(providerName); if (managedBaseUrl) { const ctx = await resolveManagedProxyContext(); @@ -110,19 +115,16 @@ async function resolveProviderCredentials( source: "managed-proxy", }; } - // Managed proxy unavailable for this provider; fall back to user key const userKey = await getProviderKeyAsync(providerName); if (userKey) { return { apiKey: userKey, source: "user-key" }; } return null; } - // "your-own" mode: check user key first, then try managed proxy fallback const userKey = await getProviderKeyAsync(providerName); if (userKey) { return { apiKey: userKey, source: "user-key" }; } - // Fall back to managed proxy even in your-own mode (backwards compat) const managedBaseUrl = await buildManagedBaseUrl(providerName); if (managedBaseUrl) { const ctx = await resolveManagedProxyContext(); @@ 
-140,6 +142,7 @@ export async function initializeProviders( ): Promise { providers.clear(); routingSources.clear(); + connectionProviders.clear(); const streamTimeoutMs = (config.timeouts?.providerStreamTimeoutSec ?? 1800) * 1000; @@ -263,3 +266,73 @@ export async function initializeProviders( routingSources.set("openrouter", "user-key"); } } + +// --------------------------------------------------------------------------- +// Per-connection provider resolution (mix-and-match support) +// --------------------------------------------------------------------------- + +/** + * Resolve a provider instance for an already-loaded `provider_connection` row. + * + * Results are cached in `connectionProviders` by connection name until the map is + * cleared (`initializeProviders`, or `clearConnectionProviderCache()` after a connection write). This + * prevents redundant vault reads for repeated calls to the same connection. + * + * Returns null (and caches nothing) when: + * - Auth resolution fails (missing credential, platform unavailable, v2 type); + * each failure is logged as a warning + * - The provider/auth combination yields no usable adapter + */ +export async function resolveProviderFromConnection( + connection: ProviderConnection, + config: ProvidersConfig, +): Promise { + const cached = connectionProviders.get(connection.name); + if (cached) return cached; + + const authResult = await resolveAuth(connection.auth, connection.provider); + if (!authResult.ok) { + const err = authResult.error; + if (err.code === "not_implemented") { + log.warn( + { connectionName: connection.name, authType: err.authType }, + `Auth type '${err.authType}' is not yet implemented (v2). 
` + + "Update the connection to use 'api_key', 'platform', or 'none'.", + ); + } else if (err.code === "credential_not_found") { + log.warn( + { connectionName: connection.name, credential: err.credential }, + `Credential '${err.credential}' not found in vault for connection '${connection.name}'.`, + ); + } else { + log.warn( + { connectionName: connection.name }, + `Platform auth unavailable for connection '${connection.name}'.`, + ); + } + return null; + } + + const streamTimeoutMs = + (config.timeouts?.providerStreamTimeoutSec ?? 1800) * 1000; + const useNativeWebSearch = + config.services["web-search"].provider === "inference-provider-native"; + const model = resolveModel(config, connection.provider); + + const provider = createAdapterFromConnection(connection, authResult.resolved, { + model, + streamTimeoutMs, + useNativeWebSearch, + }); + + if (provider) { + connectionProviders.set(connection.name, provider); + } + + return provider; +} + +/** Clear the per-connection provider cache; the connection create/update/delete helpers call this after each write (initializeProviders clears the map directly). 
*/ +export function clearConnectionProviderCache(): void { + connectionProviders.clear(); +} diff --git a/gateway/src/risk/command-registry/commands/assistant.ts b/gateway/src/risk/command-registry/commands/assistant.ts index d436f33387b..e5710c527e5 100644 --- a/gateway/src/risk/command-registry/commands/assistant.ts +++ b/gateway/src/risk/command-registry/commands/assistant.ts @@ -128,6 +128,13 @@ const ASSISTANT_SUPPORTED_COMMAND_PATHS = [ "image-generation", "image-generation generate", "inference", + "inference providers", + "inference providers connections", + "inference providers connections create", + "inference providers connections delete", + "inference providers connections get", + "inference providers connections list", + "inference providers connections update", "inference send", "inference session", "inference session open", @@ -403,6 +410,31 @@ const riskOverrides: AssistantRiskOverride[] = [ { path: "email send", risk: "high" }, { path: "image-generation generate", risk: "medium" }, { path: "inference send", risk: "medium" }, + { + path: "inference providers connections list", + risk: "low", + reason: "Read-only listing of provider_connection rows", + }, + { + path: "inference providers connections get", + risk: "low", + reason: "Read-only fetch of a single provider_connection row", + }, + { + path: "inference providers connections create", + risk: "medium", + reason: "Inserts a provider_connection row referenced by inference profiles", + }, + { + path: "inference providers connections update", + risk: "medium", + reason: "Mutates provider_connection auth config in place", + }, + { + path: "inference providers connections delete", + risk: "medium", + reason: "Deletes a provider_connection row; refuses unless --force when profiles still reference it", + }, { path: "llm send", risk: "medium" }, { path: "inference session open",