diff --git a/assistant/src/__tests__/config-watcher.test.ts b/assistant/src/__tests__/config-watcher.test.ts index 2db22df6f9c..d5f2de0b5b4 100644 --- a/assistant/src/__tests__/config-watcher.test.ts +++ b/assistant/src/__tests__/config-watcher.test.ts @@ -118,6 +118,12 @@ mock.module("../providers/registry.js", () => ({ listProviders: () => [], getProviderRoutingSource: () => undefined, initializeProviders: () => {}, + // Required by `providers/inference/connections.ts` and + // `providers/connection-resolution.ts`, both loaded transitively when + // ConfigWatcher's deps resolve. Without these, the import chain throws + // "Export named '...' not found in module 'registry.ts'". + clearConnectionProviderCache: () => {}, + resolveProviderFromConnection: async () => null, })); mock.module("../daemon/mcp-reload-service.js", () => ({ diff --git a/assistant/src/config/schemas/llm.ts b/assistant/src/config/schemas/llm.ts index 5cf96e31e46..7de0cbec660 100644 --- a/assistant/src/config/schemas/llm.ts +++ b/assistant/src/config/schemas/llm.ts @@ -283,6 +283,16 @@ const PricingOverrideSchema = z.object({ */ export const LLMConfigBase = z.object({ provider: LLMProvider.default("anthropic"), + /** + * Name of a `provider_connections` row to use for this resolved config. + * Optional and additive: when set, the dispatcher resolves auth from the + * connection (mix-and-match managed/your-own per profile). When unset, + * the dispatcher falls back to the legacy `provider` lookup. + * + * Lives on the merged base type so it flows through `resolveCallSiteConfig` + * naturally — the underlying profile-level field is on `ProfileEntry`. 
+ */ + provider_connection: z.string().min(1).optional(), model: ModelSchema.default("claude-opus-4-7"), maxTokens: MaxTokensSchema.default(64000), effort: EffortEnum.default("max"), diff --git a/assistant/src/daemon/approval-generators.ts b/assistant/src/daemon/approval-generators.ts index b1b6812a75d..cb4135af8ea 100644 --- a/assistant/src/daemon/approval-generators.ts +++ b/assistant/src/daemon/approval-generators.ts @@ -1,6 +1,6 @@ import { loadConfig } from "../config/loader.js"; -import { CallSiteRoutingProvider } from "../providers/call-site-routing.js"; -import { getProvider, listProviders } from "../providers/registry.js"; +import { wrapWithCallSiteRouting } from "../providers/call-site-routing.js"; +import { resolveDefaultProvider } from "../providers/connection-resolution.js"; import type { Provider } from "../providers/types.js"; import { APPROVAL_COPY_MAX_TOKENS, @@ -79,15 +79,16 @@ const VALID_DISPOSITIONS: ReadonlySet = new Set([ export function createApprovalCopyGenerator(): ApprovalCopyGenerator { return async (context, options = {}) => { const config = loadConfig(); - let baseProvider: Provider; - try { - baseProvider = getProvider(config.llm.default.provider); - } catch { - return null; - } + // Connection-aware default-provider resolution. If the default profile + // names a `provider_connection`, route through that connection's auth; + // otherwise fall through to the legacy registry lookup. + const baseProvider: Provider | null = await resolveDefaultProvider(config); + if (!baseProvider) return null; // Wrap so per-call `callSite` can route to an alternative provider // transport when `llm.callSites..provider` overrides the default. - const provider = wrapWithCallSiteRouting(baseProvider); + // The `wrapWithCallSiteRouting` helper threads `config` through so the + // wrapper's per-call resolution is also connection-aware. 
+ const provider = wrapWithCallSiteRouting(baseProvider, config); const fallbackText = options.fallbackText?.trim() || getFallbackMessage(context); @@ -136,12 +137,18 @@ export function createApprovalCopyGenerator(): ApprovalCopyGenerator { export function createApprovalConversationGenerator(): ApprovalConversationGenerator { return async (context) => { const config = loadConfig(); - if (!listProviders().includes(config.llm.default.provider)) { + // Connection-aware default + per-call routing. `resolveDefaultProvider` + // returns null when neither the `provider_connection` path nor the + // legacy registry can produce a Provider, which is the right "no + // provider available" signal here. (We do not pre-gate on + // `listProviders()` because in `your-own` configurations the default + // provider may live entirely behind a `provider_connection` and never + // appear in the legacy registry list.) + const baseProvider = await resolveDefaultProvider(config); + if (!baseProvider) { throw new Error("No provider available for approval conversation"); } - const provider = wrapWithCallSiteRouting( - getProvider(config.llm.default.provider), - ); + const provider = wrapWithCallSiteRouting(baseProvider, config); const pendingDescription = context.pendingApprovals .map((p) => `- Request ${p.requestId}: tool "${p.toolName}"`) @@ -212,19 +219,3 @@ export function createApprovalConversationGenerator(): ApprovalConversationGener return result; }; } - -/** - * Wrap a base Provider so per-call `callSite` metadata can route the actual - * transport to a different provider when `llm.callSites..provider` - * differs from the default. Without this wrapper, only request metadata - * reflects the callSite — the HTTP transport stays bound to the default. 
- */ -function wrapWithCallSiteRouting(base: Provider): Provider { - return new CallSiteRoutingProvider(base, (name) => { - try { - return getProvider(name); - } catch { - return undefined; - } - }); -} diff --git a/assistant/src/daemon/conversation-store.ts b/assistant/src/daemon/conversation-store.ts index 2928feb34da..81fc66cd1db 100644 --- a/assistant/src/daemon/conversation-store.ts +++ b/assistant/src/daemon/conversation-store.ts @@ -17,9 +17,9 @@ import { getConfig } from "../config/loader.js"; import type { CesClient } from "../credential-execution/client.js"; import { buildSystemPrompt } from "../prompts/system-prompt.js"; -import { CallSiteRoutingProvider } from "../providers/call-site-routing.js"; +import { wrapWithCallSiteRouting } from "../providers/call-site-routing.js"; +import { resolveDefaultProvider } from "../providers/connection-resolution.js"; import { RateLimitProvider } from "../providers/ratelimit.js"; -import { getProvider } from "../providers/registry.js"; import { getSubagentManager } from "../subagent/index.js"; import { getSandboxWorkingDir } from "../util/platform.js"; import { Conversation } from "./conversation.js"; @@ -222,14 +222,18 @@ export async function getOrCreateConversation( const createPromise = (async () => { const config = getConfig(); - let provider = getProvider(config.llm.default.provider); - provider = new CallSiteRoutingProvider(provider, (name) => { - try { - return getProvider(name); - } catch { - return undefined; - } - }); + // Connection-aware default-provider resolution. When the default + // profile names a `provider_connection`, route through that + // connection's auth; otherwise fall through to the legacy registry. 
+ const baseProvider = await resolveDefaultProvider(config); + if (!baseProvider) { + throw new Error( + `Conversation: default provider '${config.llm.default.provider}' is not registered`, + ); + } + // Per-call `callSite` routing layered on top, with connection-awareness + // for alternate profiles (matches the canonical dispatch path). + let provider = wrapWithCallSiteRouting(baseProvider, config); const { rateLimit } = config; if (rateLimit.maxRequestsPerMinute > 0) { provider = new RateLimitProvider( diff --git a/assistant/src/daemon/guardian-action-generators.ts b/assistant/src/daemon/guardian-action-generators.ts index 0a484a2cc97..321f34b5d44 100644 --- a/assistant/src/daemon/guardian-action-generators.ts +++ b/assistant/src/daemon/guardian-action-generators.ts @@ -1,7 +1,6 @@ -import { CallSiteRoutingProvider } from "../providers/call-site-routing.js"; +import { loadConfig } from "../config/loader.js"; +import { wrapWithCallSiteRouting } from "../providers/call-site-routing.js"; import { getConfiguredProvider } from "../providers/provider-send-message.js"; -import { getProvider } from "../providers/registry.js"; -import type { Provider } from "../providers/types.js"; import { buildGuardianActionGenerationPrompt, getGuardianActionFallbackMessage, @@ -32,8 +31,10 @@ export function createGuardianActionCopyGenerator(): GuardianActionCopyGenerator if (!baseProvider) return null; // Wrap so the per-call `callSite` can route to a different provider // transport when `llm.callSites.guardianQuestionCopy.provider` overrides - // the default. Without this, callSite only affects request metadata. - const provider = wrapWithCallSiteRouting(baseProvider); + // the default. Connection-aware: when the resolved profile names a + // `provider_connection`, that connection's auth wins over the legacy + // registry lookup. See `wrapWithCallSiteRouting`. 
+ const provider = wrapWithCallSiteRouting(baseProvider, loadConfig()); const fallbackText = options.fallbackText?.trim() || getGuardianActionFallbackMessage(context); @@ -135,7 +136,7 @@ export function createGuardianFollowUpConversationGenerator(): GuardianFollowUpC if (!baseProvider) { throw new Error("No configured provider available for follow-up conversation"); } - const provider = wrapWithCallSiteRouting(baseProvider); + const provider = wrapWithCallSiteRouting(baseProvider, loadConfig()); const userPrompt = [ `Original question from the voice call: "${context.questionText}"`, @@ -192,19 +193,3 @@ export function createGuardianFollowUpConversationGenerator(): GuardianFollowUpC return result; }; } - -/** - * Wrap a base Provider so per-call `callSite` metadata can route the actual - * transport to a different provider when `llm.callSites..provider` - * differs from the default. Without this wrapper, only request metadata - * reflects the callSite — the HTTP transport stays bound to the default. - */ -function wrapWithCallSiteRouting(base: Provider): Provider { - return new CallSiteRoutingProvider(base, (name) => { - try { - return getProvider(name); - } catch { - return undefined; - } - }); -} diff --git a/assistant/src/home/rollup-producer.ts b/assistant/src/home/rollup-producer.ts index 0c9a184d1be..efae258ac7d 100644 --- a/assistant/src/home/rollup-producer.ts +++ b/assistant/src/home/rollup-producer.ts @@ -36,7 +36,7 @@ */ import { loadConfig } from "../config/loader.js"; -import { getProvider, listProviders } from "../providers/registry.js"; +import { resolveDefaultProvider } from "../providers/connection-resolution.js"; import type { Provider } from "../providers/types.js"; import { getLogger } from "../util/logger.js"; import { @@ -172,7 +172,12 @@ export interface RollupProducerDeps { Awaited> >; loadRecentActions?: () => FeedItem[]; - resolveProvider?: () => Provider | null; + /** + * Test injection point for the default provider. 
May be sync or async to + * support both legacy stubs and the connection-aware path that loads + * `provider_connection` rows from the DB. + */ + resolveProvider?: () => Provider | null | Promise; } /** @@ -216,8 +221,8 @@ async function runRollupProducerInner( const loadRecentActions = deps.loadRecentActions ?? defaultLoadRecentActions; const provider = deps.resolveProvider - ? deps.resolveProvider() - : resolveDefaultProvider(); + ? await deps.resolveProvider() + : await resolveDefaultProvider(loadConfig()); if (!provider) { return { wroteCount: 0, skippedReason: "no_provider" }; } @@ -292,14 +297,6 @@ async function runRollupProducerInner( return { wroteCount, skippedReason: null }; } -function resolveDefaultProvider(): ReturnType | null { - const config = loadConfig(); - if (!listProviders().includes(config.llm.default.provider)) { - return null; - } - return getProvider(config.llm.default.provider); -} - /** * Default recent-actions loader. Reads the TTL-filtered home feed, * keeps only `action` items, and returns them sorted by `createdAt` diff --git a/assistant/src/providers/__tests__/dispatch-connection-routing.test.ts b/assistant/src/providers/__tests__/dispatch-connection-routing.test.ts new file mode 100644 index 00000000000..8ba628eefa4 --- /dev/null +++ b/assistant/src/providers/__tests__/dispatch-connection-routing.test.ts @@ -0,0 +1,274 @@ +/** + * Cycle-3 gate test — proves that `resolveConfiguredProvider` actually routes + * through `resolveProviderFromConnection` when a profile names a + * `provider_connection`. + * + * Why this exists: cycle-1 and cycle-2 both shipped `resolveProviderFromConnection` + * as dead code (zero call sites), and the cycle-2 "mix-and-match" test only + * validated DB shape — never that the dispatcher actually invoked the + * resolver. This test fails if the wiring regresses, by spying on + * `resolveProviderFromConnection` and asserting: + * + * 1. 
It was called once per dispatch invocation when the profile has a + * `provider_connection`. + * 2. The connection passed in matches the profile's `provider_connection`. + * 3. The returned `Provider` from each dispatch is the per-connection + * stub (different instances for different connections, regardless of + * shared underlying provider impl name). + * + * Two profiles, same `provider: anthropic`, different `provider_connection`: + * exactly the mix-and-match scenario goal #2 of the design. If the dispatcher + * falls back to `getProvider(name)`, both profiles would route to the same + * Provider instance and this test would catch it. + */ + +import { beforeEach, describe, expect, mock, test } from "bun:test"; + +// --------------------------------------------------------------------------- +// Module mocks (must be declared before the import-under-test). +// --------------------------------------------------------------------------- + +mock.module("../../util/logger.js", () => ({ + getLogger: () => + new Proxy({} as Record, { get: () => () => {} }), +})); + +// Test fixtures for the mocked config loader. +let mockLlmConfig: Record = {}; + +mock.module("../../config/loader.js", () => ({ + getConfig: () => ({ + llm: mockLlmConfig, + services: { inference: { mode: "your-own" } }, + }), +})); + +// Mock the DB getter — we never actually hit SQLite since `getConnection` is +// also mocked. Returning a sentinel keeps the call signature satisfied. +const mockDbSentinel = { __mock: "db" }; +mock.module("../../memory/db-connection.js", () => ({ + getDb: () => mockDbSentinel, +})); + +// Spy storage for the resolver — each test inspects what was passed in. +type Connection = { + name: string; + provider: string; + auth: { type: string; credential?: string }; +}; + +const resolveProviderCalls: Connection[] = []; + +// Each connection name maps to a distinct fake Provider instance. 
Returning +// distinguishable instances lets the test assert that two profiles with +// different connections route to different providers. +const fakeProviders = new Map(); + +// Connection registry the mocked `getConnection` reads from. +const fakeConnections = new Map(); + +mock.module("../inference/connections.js", () => ({ + getConnection: (_db: unknown, name: string) => + fakeConnections.get(name) ?? null, +})); + +mock.module("../registry.js", () => ({ + // Legacy fallback path — tests that exercise it provide their own entries. + getProvider: (name: string) => { + const p = fakeProviders.get(`legacy:${name}`); + if (!p) throw new Error(`legacy getProvider unknown: ${name}`); + return p; + }, + initializeProviders: async () => {}, + listProviders: () => Array.from(fakeProviders.values()), + // Spy on the connection resolver — records every connection the dispatcher hands it. + resolveProviderFromConnection: async (connection: Connection) => { + resolveProviderCalls.push(connection); + return fakeProviders.get(`conn:${connection.name}`) ?? null; + }, +})); + +// --------------------------------------------------------------------------- +// Imports (after mocks).
+// --------------------------------------------------------------------------- + +import { getConfiguredProvider } from "../provider-send-message.js"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function setLlmConfig(c: Record): void { + mockLlmConfig = c; +} + +function registerConnection(c: Connection, providerStub: { name: string; tag: string }): void { + fakeConnections.set(c.name, c); + fakeProviders.set(`conn:${c.name}`, providerStub); +} + +function reset(): void { + resolveProviderCalls.length = 0; + fakeConnections.clear(); + fakeProviders.clear(); + mockLlmConfig = {}; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("dispatch routes through provider_connection (cycle-3 gate)", () => { + beforeEach(reset); + + test("two profiles, same provider, different connections → resolver called twice with the right connection each time", async () => { + // Same underlying provider impl, two distinguishable connection-bound + // Provider stubs. 
+ registerConnection( + { + name: "anthropic-managed", + provider: "anthropic", + auth: { type: "platform" }, + }, + { name: "anthropic", tag: "managed-stub" }, + ); + registerConnection( + { + name: "anthropic-personal", + provider: "anthropic", + auth: { + type: "api_key", + credential: "credential/test/anthropic", + }, + }, + { name: "anthropic", tag: "personal-stub" }, + ); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + "anthropic-managed-profile": { + provider: "anthropic", + provider_connection: "anthropic-managed", + }, + "anthropic-personal-profile": { + provider: "anthropic", + provider_connection: "anthropic-personal", + }, + }, + }); + + const managedResult = await getConfiguredProvider("mainAgent", { + overrideProfile: "anthropic-managed-profile", + }); + const personalResult = await getConfiguredProvider("mainAgent", { + overrideProfile: "anthropic-personal-profile", + }); + + // Hard gate #1: the resolver was called — at all. + expect(resolveProviderCalls.length).toBe(2); + + // Hard gate #2: each call received the right connection by name. + expect(resolveProviderCalls[0].name).toBe("anthropic-managed"); + expect(resolveProviderCalls[1].name).toBe("anthropic-personal"); + + // Hard gate #3: the auth bundle on the connection matches what we'd + // expect at adapter-call time. Different auth types per profile = mix- + // and-match works. + expect(resolveProviderCalls[0].auth.type).toBe("platform"); + expect(resolveProviderCalls[1].auth.type).toBe("api_key"); + expect(resolveProviderCalls[1].auth.credential).toBe( + "credential/test/anthropic", + ); + + // Sanity: dispatch returned non-null for both. 
+ expect(managedResult).not.toBeNull(); + expect(personalResult).not.toBeNull(); + }); + + test("profile WITHOUT provider_connection falls back to legacy registry dispatch", async () => { + fakeProviders.set("legacy:anthropic", { + name: "anthropic", + tag: "legacy-stub", + }); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + "legacy-profile": { + provider: "anthropic", + // no provider_connection — must use getProvider() fallback + }, + }, + }); + + const result = await getConfiguredProvider("mainAgent", { + overrideProfile: "legacy-profile", + }); + + // Resolver must NOT have been called — legacy path only. + expect(resolveProviderCalls.length).toBe(0); + expect(result).not.toBeNull(); + }); + + test("provider_connection set but unknown → falls back to legacy registry dispatch", async () => { + // No connection registered — dispatcher should warn and fall through. + fakeProviders.set("legacy:anthropic", { + name: "anthropic", + tag: "legacy-stub", + }); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + broken: { + provider: "anthropic", + provider_connection: "does-not-exist", + }, + }, + }); + + const result = await getConfiguredProvider("mainAgent", { + overrideProfile: "broken", + }); + + // Resolver was NOT called (lookup failed before reaching it). + expect(resolveProviderCalls.length).toBe(0); + // Legacy path returned a provider — system stays operational. + expect(result).not.toBeNull(); + }); + + test("provider_connection set, connection found, but resolver returns null → falls back to legacy", async () => { + // Connection exists but resolver returns null (e.g., missing credential). 
+ fakeConnections.set("anthropic-broken-personal", { + name: "anthropic-broken-personal", + provider: "anthropic", + auth: { type: "api_key", credential: "credential/missing" }, + }); + // intentionally do NOT register a fakeProviders entry for `conn:anthropic-broken-personal` + fakeProviders.set("legacy:anthropic", { + name: "anthropic", + tag: "legacy-stub", + }); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + "broken-creds": { + provider: "anthropic", + provider_connection: "anthropic-broken-personal", + }, + }, + }); + + const result = await getConfiguredProvider("mainAgent", { + overrideProfile: "broken-creds", + }); + + // Resolver WAS called — but returned null, so we fell back. + expect(resolveProviderCalls.length).toBe(1); + expect(resolveProviderCalls[0].name).toBe("anthropic-broken-personal"); + // Legacy fallback succeeded — system stays operational. + expect(result).not.toBeNull(); + }); +}); diff --git a/assistant/src/providers/__tests__/satellite-connection-routing.test.ts b/assistant/src/providers/__tests__/satellite-connection-routing.test.ts new file mode 100644 index 00000000000..0dec6f6a6ed --- /dev/null +++ b/assistant/src/providers/__tests__/satellite-connection-routing.test.ts @@ -0,0 +1,458 @@ +/** + * Cycle-3 satellite-path gate test. + * + * The dispatcher gate (`dispatch-connection-routing.test.ts`) proves that + * the canonical `getConfiguredProvider()` path honors `provider_connection`. + * That path is used by `provider-send-message.ts` directly. The satellite + * sites — daemon conversation/approval/guardian generators, subagent + * manager, rollup producer — instead build a `CallSiteRoutingProvider` once + * at construction time and reuse it across many `sendMessage` calls, + * routing per-call via `options.config.callSite`. 
+ * + * If `CallSiteRoutingProvider` falls back to `getProvider(name)` when an + * alternate-callSite profile names a `provider_connection`, the satellites + * silently lose connection-awareness for any callSite distinct from the + * default profile. This test proves the wrapper now consults the + * connection-resolution hook before the legacy registry. + * + * Hard gates: + * 1. A call with `callSite: ` whose profile names a connection + * invokes the connection-resolution hook with that name. + * 2. The actual sendMessage transport that runs is the connection-bound + * Provider stub, not the default and not the legacy `getProvider(name)` + * result. + * 3. A call with `callSite: ` whose profile has NO connection still + * falls through to legacy `getProvider(name)`. + * 4. A call with no callSite goes straight to the default provider — no + * hook invocation, no registry lookup. + */ + +import { beforeEach, describe, expect, mock, test } from "bun:test"; + +import type { Provider, ProviderResponse } from "../types.js"; + +// --------------------------------------------------------------------------- +// Module mocks (must be declared before the import-under-test). +// --------------------------------------------------------------------------- + +mock.module("../../util/logger.js", () => ({ + getLogger: () => + new Proxy({} as Record, { get: () => () => {} }), +})); + +let mockLlmConfig: Record = {}; + +mock.module("../../config/loader.js", () => ({ + getConfig: () => ({ + llm: mockLlmConfig, + services: { inference: { mode: "your-own" } }, + }), + loadConfig: () => ({ + llm: mockLlmConfig, + services: { inference: { mode: "your-own" } }, + }), +})); + +const mockDbSentinel = { __mock: "db" }; +mock.module("../../memory/db-connection.js", () => ({ + getDb: () => mockDbSentinel, +})); + +// --------------------------------------------------------------------------- +// Fake provider/connection registries — keep these inspectable from tests. 
+// --------------------------------------------------------------------------- + +type Connection = { + name: string; + provider: string; + auth: { type: string; credential?: string }; +}; + +// Provider-conforming stub. The `tag` field on the returned response lets +// the test assert which transport actually ran (the connection-bound stub +// vs the legacy registry stub vs the bare default), without leaning on +// reference equality. +interface TaggedResponse extends ProviderResponse { + tag: string; +} +type FakeProviderStub = Provider & { + tag: string; + sendMessage: ( + ...args: Parameters + ) => Promise; +}; + +const fakeConnections = new Map(); +const fakeProviders = new Map(); +const resolveProviderCalls: Connection[] = []; +const sendMessageCalls: { tag: string }[] = []; + +function makeFakeProvider(tag: string, providerName: string): FakeProviderStub { + return { + name: providerName, + tag, + sendMessage: async () => { + sendMessageCalls.push({ tag }); + return { + content: [{ type: "text", text: tag }], + model: "test-model", + usage: { inputTokens: 1, outputTokens: 1 }, + stopReason: "end_turn", + tag, + }; + }, + }; +} + +mock.module("../inference/connections.js", () => ({ + getConnection: (_db: unknown, name: string) => + fakeConnections.get(name) ?? null, +})); + +// Connection names that should make `resolveProviderFromConnection` throw — +// simulates a transient failure inside auth resolution (credential read, +// managed-proxy context lookup) bubbling up from the inner registry call. 
+const connectionsThatThrowOnResolve = new Set(); + +mock.module("../registry.js", () => ({ + getProvider: (name: string) => { + const p = fakeProviders.get(`legacy:${name}`); + if (!p) throw new Error(`legacy getProvider unknown: ${name}`); + return p; + }, + initializeProviders: async () => {}, + listProviders: () => Array.from(fakeProviders.values()), + resolveProviderFromConnection: async (connection: Connection) => { + resolveProviderCalls.push(connection); + if (connectionsThatThrowOnResolve.has(connection.name)) { + throw new Error(`simulated auth-resolution failure: ${connection.name}`); + } + return fakeProviders.get(`conn:${connection.name}`) ?? null; + }, +})); + +// --------------------------------------------------------------------------- +// Imports (after mocks). +// --------------------------------------------------------------------------- + +import { wrapWithCallSiteRouting } from "../call-site-routing.js"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function setLlmConfig(c: Record): void { + mockLlmConfig = c; +} + +function registerConnection( + c: Connection, + providerStub: FakeProviderStub, +): void { + fakeConnections.set(c.name, c); + fakeProviders.set(`conn:${c.name}`, providerStub); +} + +function reset(): void { + resolveProviderCalls.length = 0; + sendMessageCalls.length = 0; + fakeConnections.clear(); + fakeProviders.clear(); + connectionsThatThrowOnResolve.clear(); + mockLlmConfig = {}; +} + +// ProvidersConfig stub used by the wrapper helper. The connection-resolution +// helper passes it straight to `resolveProviderFromConnection`, which is +// fully mocked above — so a minimal shape is fine. 
+const providersConfigStub = { + llm: { default: { provider: "anthropic", model: "claude-opus-4-7" } }, + services: { + inference: { mode: "your-own" as const }, + "image-generation": { + mode: "managed" as const, + provider: "openai", + model: "gpt-image-1", + }, + "web-search": { mode: "managed" as const, provider: "brave" }, + }, +}; + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("CallSiteRoutingProvider honors provider_connection (satellite gate)", () => { + beforeEach(reset); + + test("alternate-profile callSite with provider_connection routes through that connection's auth", async () => { + // Default = anthropic, but the rollup callSite is configured to use a + // different profile that names a `provider_connection`. + const defaultProvider = makeFakeProvider("default-anthropic", "anthropic"); + fakeProviders.set("legacy:anthropic", defaultProvider); + + registerConnection( + { + name: "anthropic-managed", + provider: "anthropic", + auth: { type: "platform" }, + }, + makeFakeProvider("connection-managed", "anthropic"), + ); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + "managed-profile": { + provider: "anthropic", + provider_connection: "anthropic-managed", + }, + }, + callSites: { + replySuggestion: { profile: "managed-profile" }, + }, + }); + + const wrapped = wrapWithCallSiteRouting( + defaultProvider, + providersConfigStub, + ); + + const response = await wrapped.sendMessage( + [{ role: "user", content: [{ type: "text", text: "hello" }] }], + [], + undefined, + { config: { callSite: "replySuggestion" } }, + ); + + // Hard gate #1: connection-resolution hook fired with the right name. 
+ expect(resolveProviderCalls.length).toBe(1); + expect(resolveProviderCalls[0].name).toBe("anthropic-managed"); + expect(resolveProviderCalls[0].auth.type).toBe("platform"); + + // Hard gate #2: the actual transport that ran was the connection-bound + // stub, NOT the default and NOT the (mocked) legacy registry result. + expect(sendMessageCalls.length).toBe(1); + expect(sendMessageCalls[0].tag).toBe("connection-managed"); + expect((response as unknown as { tag: string }).tag).toBe("connection-managed"); + }); + + test("alternate-profile callSite WITHOUT provider_connection falls through to legacy registry", async () => { + const defaultProvider = makeFakeProvider("default-anthropic", "anthropic"); + fakeProviders.set("legacy:anthropic", defaultProvider); + fakeProviders.set("legacy:openai", makeFakeProvider("legacy-openai", "openai")); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + "openai-profile": { + provider: "openai", + // no provider_connection — must use getProvider("openai") fallback + }, + }, + callSites: { + memoryRetrieval: { profile: "openai-profile" }, + }, + }); + + const wrapped = wrapWithCallSiteRouting( + defaultProvider, + providersConfigStub, + ); + + await wrapped.sendMessage( + [{ role: "user", content: [{ type: "text", text: "hello" }] }], + [], + undefined, + { config: { callSite: "memoryRetrieval" } }, + ); + + // Connection-resolution hook MUST NOT have fired. + expect(resolveProviderCalls.length).toBe(0); + // Legacy registry path produced the openai stub. 
+ expect(sendMessageCalls.length).toBe(1); + expect(sendMessageCalls[0].tag).toBe("legacy-openai"); + }); + + test("alternate-profile callSite with unknown provider_connection falls through to legacy", async () => { + const defaultProvider = makeFakeProvider("default-anthropic", "anthropic"); + fakeProviders.set("legacy:anthropic", defaultProvider); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + broken: { + provider: "anthropic", + provider_connection: "does-not-exist", + }, + }, + callSites: { + conversationTitle: { profile: "broken" }, + }, + }); + + const wrapped = wrapWithCallSiteRouting( + defaultProvider, + providersConfigStub, + ); + + const response = await wrapped.sendMessage( + [{ role: "user", content: [{ type: "text", text: "hello" }] }], + [], + undefined, + { config: { callSite: "conversationTitle" } }, + ); + + // Connection lookup returned null, so the resolver hook never fired. + expect(resolveProviderCalls.length).toBe(0); + // Profile's resolved provider matches default → reused default + // instance (no legacy lookup needed). System stays operational. + expect(sendMessageCalls.length).toBe(1); + expect(sendMessageCalls[0].tag).toBe("default-anthropic"); + expect((response as unknown as { tag: string }).tag).toBe("default-anthropic"); + }); + + test("provider/connection mismatch falls through to legacy — no silent misroute", async () => { + // Misconfiguration: profile says provider=openai but provider_connection + // points at an anthropic-flavored row. Without the validation we'd dispatch + // OpenAI traffic to an Anthropic backend (or vice versa). With validation + // we fall through to the legacy `getProvider("openai")` path so the + // request goes where the profile's `provider` field said.
+ const defaultProvider = makeFakeProvider("default-anthropic", "anthropic"); + fakeProviders.set("legacy:anthropic", defaultProvider); + fakeProviders.set( + "legacy:openai", + makeFakeProvider("legacy-openai", "openai"), + ); + + registerConnection( + { + name: "anthropic-managed", + provider: "anthropic", + auth: { type: "platform" }, + }, + // Note: even though the connection has a stub bound, it should NEVER + // be reached because the connection's provider doesn't match the + // profile's provider. + makeFakeProvider("WRONG-connection-anthropic", "anthropic"), + ); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + mismatched: { + provider: "openai", + // ↑ profile says openai + provider_connection: "anthropic-managed", + // ↑ but connection is anthropic — mismatch + }, + }, + callSites: { + replySuggestion: { profile: "mismatched" }, + }, + }); + + const wrapped = wrapWithCallSiteRouting( + defaultProvider, + providersConfigStub, + ); + + await wrapped.sendMessage( + [{ role: "user", content: [{ type: "text", text: "hello" }] }], + [], + undefined, + { config: { callSite: "replySuggestion" } }, + ); + + // The hook MUST NOT have produced a Provider — the validation check + // returned null without reaching `resolveProviderFromConnection`. + expect(resolveProviderCalls.length).toBe(0); + // Legacy registry path produced the openai stub (matching profile.provider, + // NOT the connection's anthropic). + expect(sendMessageCalls.length).toBe(1); + expect(sendMessageCalls[0].tag).toBe("legacy-openai"); + }); + + test("transient auth-resolution failure falls through to legacy — does NOT hard-fail dispatch", async () => { + // Simulates a transient error inside `resolveProviderFromConnection` + // (e.g. a credential read fails, or managed-proxy context lookup + // throws). 
The wrapper MUST log and fall through to the legacy + // registry path; throwing through to the dispatcher would take + // inference offline for any callsite using a connection-bound profile. + const defaultProvider = makeFakeProvider("default-anthropic", "anthropic"); + fakeProviders.set("legacy:anthropic", defaultProvider); + + registerConnection( + { + name: "flaky-managed", + provider: "anthropic", + auth: { type: "platform" }, + }, + // Provider stub IS registered, but the resolve will throw before + // reaching it. The test asserts the throw is caught. + makeFakeProvider("WOULD-BE-connection", "anthropic"), + ); + connectionsThatThrowOnResolve.add("flaky-managed"); + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + profiles: { + flaky: { + provider: "anthropic", + provider_connection: "flaky-managed", + }, + }, + callSites: { + replySuggestion: { profile: "flaky" }, + }, + }); + + const wrapped = wrapWithCallSiteRouting( + defaultProvider, + providersConfigStub, + ); + + // This MUST NOT throw — the resolve failure is contained. + await wrapped.sendMessage( + [{ role: "user", content: [{ type: "text", text: "hello" }] }], + [], + undefined, + { config: { callSite: "replySuggestion" } }, + ); + + // The hook DID fire (we got past the connection lookup + validation). + expect(resolveProviderCalls.length).toBe(1); + expect(resolveProviderCalls[0].name).toBe("flaky-managed"); + // ...but the throw was caught and we fell through. Profile's + // resolved provider matches default → reused default instance. + expect(sendMessageCalls.length).toBe(1); + expect(sendMessageCalls[0].tag).toBe("default-anthropic"); + }); + + test("call without a callSite goes straight to the default provider — no hook, no registry lookup", async () => { + const defaultProvider = makeFakeProvider("default-anthropic", "anthropic"); + + // Note: legacy registry has nothing — if the wrapper tries to consult + // it, the test will throw. 
Bare-default path proves the short-circuit. + + setLlmConfig({ + default: { provider: "anthropic", model: "claude-opus-4-7" }, + }); + + const wrapped = wrapWithCallSiteRouting( + defaultProvider, + providersConfigStub, + ); + + await wrapped.sendMessage( + [{ role: "user", content: [{ type: "text", text: "hello" }] }], + [], + undefined, + {}, + ); + + expect(resolveProviderCalls.length).toBe(0); + expect(sendMessageCalls.length).toBe(1); + expect(sendMessageCalls[0].tag).toBe("default-anthropic"); + }); +}); diff --git a/assistant/src/providers/call-site-routing.ts b/assistant/src/providers/call-site-routing.ts index d2aaf54e0c3..34534950958 100644 --- a/assistant/src/providers/call-site-routing.ts +++ b/assistant/src/providers/call-site-routing.ts @@ -23,6 +23,9 @@ import { AsyncLocalStorage } from "node:async_hooks"; import { resolveCallSiteConfig } from "../config/llm-resolver.js"; import { getConfig } from "../config/loader.js"; +import { tryResolveProviderForConnectionName } from "./connection-resolution.js"; +import type { ProvidersConfig } from "./registry.js"; +import { getProvider } from "./registry.js"; import type { Message, Provider, @@ -53,6 +56,24 @@ export class CallSiteRoutingProvider implements Provider { constructor( private readonly defaultProvider: Provider, private readonly getProviderByName: (name: string) => Provider | undefined, + /** + * Optional async hook invoked when the resolved profile names a + * `provider_connection`. Returning a Provider routes the call through + * that connection's auth; returning null falls through to the + * legacy `getProviderByName(resolved.provider)` path. + * + * `expectedProvider` is the provider name the resolved profile declared. + * The hook should verify the connection's provider matches and fall + * through (return null) on mismatch. + * + * Optional so existing callers without connection-awareness still + * compile; satellites pass `tryResolveProviderForConnectionName`-bound + * closures to opt in. 
+ */ + private readonly resolveByConnection?: ( + connectionName: string, + expectedProvider: string, + ) => Promise, ) { this.tokenEstimationProvider = defaultProvider.tokenEstimationProvider; } @@ -63,7 +84,7 @@ export class CallSiteRoutingProvider implements Provider { systemPrompt?: string, options?: SendMessageOptions, ): Promise { - const target = this.selectProvider(options); + const target = await this.selectProvider(options); const isRouted = target !== this.defaultProvider; const doSend = async (): Promise => { @@ -91,12 +112,21 @@ export class CallSiteRoutingProvider implements Provider { } /** - * Pick the provider to route this call through. The default provider wins - * unless the per-call `callSite` (layered with any `overrideProfile`) - * resolves to a different provider name and the registry can produce a - * Provider for it. + * Pick the provider to route this call through. + * + * Resolution order: + * 1. No callSite → default provider (legacy short-circuit). + * 2. Resolved profile names a `provider_connection` → async-resolve + * through that connection's auth via `resolveByConnection`. On miss + * we fall through to the next step (don't break inference). + * 3. Resolved profile's `provider` matches the default's name → reuse + * the default provider instance (avoids redundant lookup). + * 4. Otherwise consult `getProviderByName(resolved.provider)`; fall + * back to default if the registry can't produce one. 
*/ - private selectProvider(options?: SendMessageOptions): Provider { + private async selectProvider( + options?: SendMessageOptions, + ): Promise { const callSite = options?.config?.callSite; if (!callSite) return this.defaultProvider; @@ -104,6 +134,15 @@ export class CallSiteRoutingProvider implements Provider { const resolved = resolveCallSiteConfig(callSite, getConfig().llm, { overrideProfile, }); + + if (resolved.provider_connection && this.resolveByConnection) { + const connectionProvider = await this.resolveByConnection( + resolved.provider_connection, + resolved.provider, + ); + if (connectionProvider) return connectionProvider; + } + if (resolved.provider === this.defaultProvider.name) { return this.defaultProvider; } @@ -112,3 +151,35 @@ export class CallSiteRoutingProvider implements Provider { return alternative ?? this.defaultProvider; } } + +/** + * Wrap a base Provider with `CallSiteRoutingProvider` configured to resolve + * alternate-profile routing through the global registry and to route + * `provider_connection` references through the shared connection-resolution + * helper. + * + * `config` is threaded through to the connection lookup so the resolved + * connection's auth can read provider-config metadata (e.g. timeouts, model + * names). + */ +export function wrapWithCallSiteRouting( + base: Provider, + config: ProvidersConfig, +): Provider { + return new CallSiteRoutingProvider( + base, + (name) => { + try { + return getProvider(name); + } catch { + return undefined; + } + }, + (connectionName, expectedProvider) => + tryResolveProviderForConnectionName( + connectionName, + config, + expectedProvider, + ), + ); +} diff --git a/assistant/src/providers/connection-resolution.ts b/assistant/src/providers/connection-resolution.ts new file mode 100644 index 00000000000..28f57f09590 --- /dev/null +++ b/assistant/src/providers/connection-resolution.ts @@ -0,0 +1,134 @@ +/** + * Connection-aware provider resolution helpers. 
+ * + * These wrap `resolveProviderFromConnection` (in `registry.ts`) with the + * DB lookup and lifecycle of a `provider_connection` reference. The + * canonical dispatch path (`provider-send-message.ts`) and each satellite + * site (subagent manager, daemon conversation/approval/guardian generators, + * rollup producer) use these helpers so that connection-awareness behaves + * identically across the codebase. + * + * Resolution policy: + * 1. If the profile names a `provider_connection`, look it up in the DB + * and resolve to a `Provider` with the connection's auth bound. + * 2. On any miss (DB lookup throws, row not found, provider mismatch, + * auth resolution fails) log a warning and return null so callers can + * fall back to legacy `getProvider(profile.provider)` dispatch. + * + * The legacy fallback is intentionally retained for one release window — + * cycle-4 cleanup will remove it once we've shipped one release with + * connection-awareness active. + */ + +import { getDb } from "../memory/db-connection.js"; +import { getLogger } from "../util/logger.js"; +import { getConnection } from "./inference/connections.js"; +import type { ProvidersConfig } from "./registry.js"; +import { + getProvider, + resolveProviderFromConnection, +} from "./registry.js"; +import type { Provider } from "./types.js"; + +const log = getLogger("providers/connection-resolution"); + +/** + * Attempt to resolve a Provider through a named `provider_connection`. Returns + * null on any miss (lookup error, row not found, provider mismatch with the + * resolving profile, auth resolution failure) so callers can fall back to the + * legacy `getProvider(name)` path. + * + * `expectedProvider` is the provider name the resolving profile declared. 
We + * verify the connection row's `provider` field matches before binding — a + * profile that names `provider: "openai"` together with an Anthropic-flavored + * `provider_connection` is a misconfiguration and we fall through rather than + * silently routing the request to the wrong backend. Pass `undefined` to skip + * the check (callers that don't yet know the expected provider). + */ +export async function tryResolveProviderForConnectionName( + connectionName: string, + config: ProvidersConfig, + expectedProvider?: string, +): Promise { + let connection; + try { + connection = getConnection(getDb(), connectionName); + } catch (err) { + log.warn( + { err, connectionName }, + "provider_connection lookup failed — falling back to legacy registry dispatch", + ); + return null; + } + if (!connection) { + log.warn( + { connectionName }, + "provider_connection not found — falling back to legacy registry dispatch", + ); + return null; + } + if (expectedProvider && connection.provider !== expectedProvider) { + log.warn( + { + connectionName, + expectedProvider, + connectionProvider: connection.provider, + }, + "provider_connection provider does not match resolving profile's provider — falling back to legacy registry dispatch to avoid silent misroute", + ); + return null; + } + // `resolveProviderFromConnection` reaches into auth resolution (credential + // reads, managed-proxy context). A transient failure there must not hard- + // fail the dispatcher — log and fall through so the legacy registry path + // can still serve the request. + try { + return await resolveProviderFromConnection(connection, config); + } catch (err) { + log.warn( + { err, connectionName }, + "provider_connection auth resolution failed — falling back to legacy registry dispatch", + ); + return null; + } +} + +/** + * Resolve the connection-aware default provider for the satellite + * construction-time path (subagent manager, conversation store, + * approval/guardian generators, rollup producer). 
+ * + * Reads `config.llm.default.{provider, provider_connection}`. If the default + * profile names a connection, tries connection-aware resolution; otherwise + * (or on miss) falls through to the legacy registry. Returns null if the + * default provider isn't initialised (so callers can early-out gracefully). + */ +export async function resolveDefaultProvider( + config: ProvidersConfig, +): Promise { + const profile = config.llm.default; + // `provider_connection` is read off the runtime config as added by + // `profileConfigFragment`; the typed view in `ProvidersConfig.llm.default` + // doesn't include it yet so cast through. The schema-level type is updated + // in `schemas/llm.ts`; this cast keeps the public `ProvidersConfig` shape + // stable for cycle-3 and is removed when the type alignment lands. + const connectionName = (profile as { provider_connection?: string }) + .provider_connection; + if (connectionName) { + const connectionProvider = await tryResolveProviderForConnectionName( + connectionName, + config, + profile.provider, + ); + if (connectionProvider) return connectionProvider; + } + try { + return getProvider(profile.provider); + } catch (err) { + log.warn( + { err, providerName: profile.provider }, + "default provider not registered — caller should treat as null", + ); + return null; + } +} diff --git a/assistant/src/providers/provider-send-message.ts b/assistant/src/providers/provider-send-message.ts index b48f3b109d0..90cb715212e 100644 --- a/assistant/src/providers/provider-send-message.ts +++ b/assistant/src/providers/provider-send-message.ts @@ -7,7 +7,12 @@ import { resolveCallSiteConfig } from "../config/llm-resolver.js"; import { getConfig } from "../config/loader.js"; import type { LLMCallSite } from "../config/schemas/llm.js"; -import { getProvider, initializeProviders, listProviders } from "./registry.js"; +import { tryResolveProviderForConnectionName } from "./connection-resolution.js"; +import { + getProvider, + 
initializeProviders, + listProviders, +} from "./registry.js"; import type { ContentBlock, Message, @@ -107,11 +112,31 @@ export async function resolveConfiguredProvider( } } - const inferenceProvider = resolveCallSiteConfig( - callSite, - config.llm, - opts, - ).provider; + const resolved = resolveCallSiteConfig(callSite, config.llm, opts); + const inferenceProvider = resolved.provider; + const connectionName = resolved.provider_connection; + + // Connection-aware path: when the resolved profile names a + // `provider_connection`, route auth through that row's resolver. Falls + // through to the legacy `getProvider(name)` path on any miss so existing + // profiles without `provider_connection` keep working unchanged. + if (connectionName) { + const connectionProvider = await tryResolveProviderForConnectionName( + connectionName, + config, + inferenceProvider, + ); + if (connectionProvider) { + return { + provider: new CallSiteConfiguredProvider( + connectionProvider, + callSite, + opts.overrideProfile, + ), + configuredProviderName: inferenceProvider, + }; + } + } try { const provider = getProvider(inferenceProvider); diff --git a/assistant/src/subagent/manager.ts b/assistant/src/subagent/manager.ts index 45a5619cbcf..13228a5e776 100644 --- a/assistant/src/subagent/manager.ts +++ b/assistant/src/subagent/manager.ts @@ -15,9 +15,9 @@ import { Conversation } from "../daemon/conversation.js"; import { findConversation } from "../daemon/conversation-store.js"; import type { ServerMessage } from "../daemon/message-protocol.js"; import { bootstrapConversation } from "../memory/conversation-bootstrap.js"; -import { CallSiteRoutingProvider } from "../providers/call-site-routing.js"; +import { wrapWithCallSiteRouting } from "../providers/call-site-routing.js"; +import { resolveDefaultProvider } from "../providers/connection-resolution.js"; import { RateLimitProvider } from "../providers/ratelimit.js"; -import { getProvider } from "../providers/registry.js"; import { 
createAbortReason } from "../util/abort-reasons.js"; import { getLogger } from "../util/logger.js"; import { getSandboxWorkingDir } from "../util/platform.js"; @@ -181,19 +181,20 @@ export class SubagentManager { // ── Build conversation dependencies ───────────────────────────── const appConfig = getConfig(); - let provider = getProvider(appConfig.llm.default.provider); + // Connection-aware default-provider resolution; falls back to legacy + // registry lookup when `llm.default.provider_connection` isn't set or + // resolution misses. Per-call `callSite` routing is layered next. + const baseProvider = await resolveDefaultProvider(appConfig); + if (!baseProvider) { + throw new Error( + `Subagent: default provider '${appConfig.llm.default.provider}' is not registered`, + ); + } // Per-call `options.config.callSite` (e.g. `subagentSpawn`) can resolve - // to a provider name that differs from `llm.default.provider`. Wrap the - // default provider so the actual transport routes correctly per call, - // rather than only forwarding metadata to the default's HTTP client. - // See `providers/call-site-routing.ts`. - provider = new CallSiteRoutingProvider(provider, (name) => { - try { - return getProvider(name); - } catch { - return undefined; - } - }); + // to a profile that differs from `llm.default`. The shared wrapper + // threads `appConfig` through so per-call alternate-profile routing is + // also connection-aware (matches the canonical dispatch path). + let provider = wrapWithCallSiteRouting(baseProvider, appConfig); const { rateLimit } = appConfig; if (rateLimit.maxRequestsPerMinute > 0) { provider = new RateLimitProvider(