diff --git a/apps/desktop/package.json b/apps/desktop/package.json index f494ce9a246..7516b0d2636 100644 --- a/apps/desktop/package.json +++ b/apps/desktop/package.json @@ -73,6 +73,7 @@ "@hono/node-server": "^1.14.1", "@hookform/resolvers": "^5.2.2", "@lezer/highlight": "^1.2.3", + "@mastra/core": "1.25.0", "@parcel/watcher": "^2.5.6", "@pierre/diffs": "1.1.3", "@radix-ui/react-dialog": "^1.1.15", @@ -200,7 +201,7 @@ "lowdb": "^7.0.1", "lowlight": "^3.3.0", "lucide-react": "^0.563.0", - "mastracode": "0.9.2", + "mastracode": "0.14.0", "nanoid": "^5.1.6", "node-addon-api": "^7.1.0", "node-pty": "1.1.0", diff --git a/apps/desktop/plans/done/20260415-v2-host-service-ai-branch-naming-test-plan.md b/apps/desktop/plans/done/20260415-v2-host-service-ai-branch-naming-test-plan.md new file mode 100644 index 00000000000..e6f2cffb604 --- /dev/null +++ b/apps/desktop/plans/done/20260415-v2-host-service-ai-branch-naming-test-plan.md @@ -0,0 +1,86 @@ +# Manual Testing Plan — PR #3517 + +## Prerequisites +- Desktop dev running (`bun dev` from apps/desktop, or full `bun dev` from root) +- At least one project configured with a git repo + +## 1. v1 AI Branch Naming (API key path) + +**Setup**: `ANTHROPIC_API_KEY` or `OPENAI_API_KEY` set in env (or stored via Settings > Models). + +| Step | Expected | +|---|---| +| Open v1 new-workspace modal (Cmd+N) | Modal opens | +| Type a prompt: "fix dropdown alignment bug" | Text entered | +| Submit (Enter or click Create) | Modal closes, pending workspace shows "Generating branch…" briefly | +| Wait for workspace to initialize | Branch name is AI-generated kebab-case (e.g. `fix-dropdown-alignment`), not random words | +| Check worktree | Branch exists locally | + +## 2. v1 AI Branch Naming (no credentials) + +**Setup**: unset `ANTHROPIC_API_KEY` and `OPENAI_API_KEY` from env. No stored API keys in Settings > Models. + +| Step | Expected | +|---|---| +| Create workspace with prompt | Branch name falls back to random friendly name (e.g. `pickle-streetcar`) or prompt-derived slug | +| No error toast | Degradation is silent | + +## 3. v1 Workspace Auto-Rename + +**Setup**: API key available. + +| Step | Expected | +|---|---| +| Create workspace with prompt "refactor auth middleware" | Workspace title updates to AI-generated name (e.g. "Refactor Auth Middleware") after a few seconds | +| If no API key available | Title falls back to prompt text or friendly name | + +## 4. Anthropic OAuth Auto-Refresh (from #3510) + +**Setup**: Anthropic OAuth configured (Claude Max). Requires waiting for token expiry or manual simulation. + +| Step | Expected | +|---|---| +| Sign in to Anthropic via OAuth in Settings > Models | "Active" badge appears | +| Force-expire: edit `~/Library/Application Support/mastracode/auth.json`, set `anthropic.expires` to a past timestamp | — | +| Send a chat message | Chat succeeds silently (token auto-refreshed via `authStorage.getApiKey`). No "Reconnect" prompt. | +| If refresh token is also invalid | Falls to expired state, "Expired" badge + "Reconnect" button appears | +| Check terminal for `[chat-service] Anthropic OAuth refresh failed` | Logged if refresh fails | + +## 5. Settings > Models Page + +| Step | Expected | +|---|---| +| Navigate to Settings > Models | Page loads with Anthropic + OpenAI sections, each with provider icon in header | +| Each provider shows a single card with OAuth row + API Key row | OAuth row: label + badge + action. API Key row: input + contextual buttons | +| **Disconnected state** | "Not connected" badge, primary "Connect" button, no Save/Clear buttons | +| **API key flow**: type key → Save appears → click Save | "API key updated" toast, "Active" badge, "Logout" button appears | +| **API key flow**: click Clear | Key removed, badge reverts to "Not connected" | +| **OAuth flow**: click Connect → complete in browser | "Active" badge, "Logout" button | +| **OAuth flow**: click Logout | Badge reverts, Connect button returns | +| **API key + OAuth**: set API key, then connect OAuth, then disconnect OAuth | API key should survive the OAuth cycle (backup/restore workaround) | +| **OpenAI dialog** auto-opens browser on Connect | No manual "Open browser" step needed | +| **Copy URL** button shows "Copied!" feedback for 2s | — | + +## 6. Production Build + +| Step | Expected | +|---|---| +| `bun run compile:app` (from apps/desktop) | Succeeds. `get-small-model` chunk ~1.2 MB, no 20 MB chunk. | +| `bun run copy:native-modules` | Succeeds | +| `bun run validate:native-runtime` | All checks pass | +| `npx electron dist/main/index.js` | Main process boots (renderer 404 expected in non-packaged mode). No onnxruntime error. | + +## 7. Host-Service Procedure (dormant — future v2) + +Not yet wired to UI. Verify via tRPC playground or direct call if available: + +| Step | Expected | +|---|---| +| Call `workspaceCreation.generateBranchName({ projectId, prompt: "fix auth bug" })` | Returns `{ branchName: "fix-auth-bug" }` or similar (requires API key in host-service env) | +| Call with empty prompt | Returns `{ branchName: null }` | +| Call with no API key in env | Returns `{ branchName: null }` (graceful fallback) | + +## Known Regressions (documented, accepted) + +- **OAuth-only users** (Claude Max / OpenAI Codex without stored API key) get random branch names and prompt-derived workspace titles for small-model tasks. Main chat retains full OAuth. +- **Upstream dependency**: API key storage slot collision with OAuth is worked around via backup/restore. Proper fix tracked at mastra-ai/mastra#15483. diff --git a/apps/desktop/plans/done/20260415-v2-host-service-ai-branch-naming.md b/apps/desktop/plans/done/20260415-v2-host-service-ai-branch-naming.md new file mode 100644 index 00000000000..5a88bb50813 --- /dev/null +++ b/apps/desktop/plans/done/20260415-v2-host-service-ai-branch-naming.md @@ -0,0 +1,168 @@ +# V2 Workspace Modal — Host-Service AI Branch Naming + +Port v1's AI branch-name generation into v2's workspace modal, routed through host-service. Approach: **use upstream `mastracode`'s `resolveModel`** via a lightweight `createMastraCode({ disableMcp: true, disableHooks: true })` singleton. Delete our small-model abstraction; keep OAuth parity (Claude Max + Codex) because mastracode handles it internally. + +## Completed + +- ✅ Bumped `mastracode` 0.9.2 → **0.14.0** (+ transitive `@mastra/core` 1.16 → 1.25). Typecheck + tests green. Removed `minimumReleaseAge` from `bunfig.toml`. + +## Target architecture + +``` +v2 useSubmitWorkspace + └─> client.workspaceCreation.generateBranchName.mutate({ projectId, prompt }) + └─> generateBranchNameFromPrompt(...) [host-service] + └─> getSmallModel() [shared helper] + └─> resolveModel(modelId) from mastracode + (full auth: API-key + keychain + OAuth middleware) +``` + +Desktop v1's existing `ai-branch-name.ts` migrates to the same `getSmallModel` helper — single implementation, two consumers. + +## Shared helper + +`packages/chat/src/server/shared/small-model/get-small-model.ts`: + +```ts +import { createAuthStorage, createMastraCode } from "mastracode"; +import type { MastraLanguageModel } from "@mastra/core/llm"; + +const ANTHROPIC_SMALL = "anthropic/claude-haiku-4-5-20251001"; +const OPENAI_SMALL = "openai/gpt-4o-mini"; + +type Resolver = Awaited>["resolveModel"]; +let initPromise: Promise | null = null; + +function getResolver(): Promise { + if (!initPromise) { + initPromise = createMastraCode({ disableMcp: true, disableHooks: true }) + .then((r) => r.resolveModel); + } + return initPromise; +} + +function pickSmallModelId(): string | null { + const auth = createAuthStorage(); + auth.reload(); + if (auth.has("anthropic")) return ANTHROPIC_SMALL; + if (auth.has("openai")) return OPENAI_SMALL; + return null; +} + +export async function getSmallModel(): Promise { + const modelId = pickSmallModelId(); + if (!modelId) return null; + const resolveModel = await getResolver(); + return resolveModel(modelId) as MastraLanguageModel; +} +``` + +Module-level promise caches the mastracode init (one-time cost per process). Credential check is per-call (cheap, in-memory). + +## Code-removal budget + +| File | LOC | Fate | +|---|---|---| +| `apps/desktop/src/lib/ai/call-small-model.ts` | 184 | delete | +| `apps/desktop/src/lib/ai/call-small-model.test.ts` | 399 | delete | +| `apps/desktop/src/lib/ai/provider-diagnostics.ts` | 89 | delete if no other consumer | +| `packages/chat/src/server/desktop/small-model/small-model.ts` | 146 | delete | +| `packages/chat/src/server/desktop/small-model/small-model.test.ts` | 391 | delete | +| `packages/chat/src/server/desktop/title-generation/title-generation.ts` | 99 | trim (~50, drop streaming variant) | +| `packages/chat/src/server/desktop/auth/anthropic/anthropic.ts` | 232 | trim (~50, keep OAuth login helpers chat-service uses) | +| `packages/chat/src/server/desktop/auth/openai/openai.ts` | 99 | trim (~30) | +| `apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-branch-name.ts` | 117 | rewrite → ~60 | +| New `shared/small-model/get-small-model.ts` | — | +50 | + +Net: **~1200 lines removed**. + +--- + +## Step 1 — Shared helper + migrate v1 branch naming + +### Actionable tasks +1. Create `packages/chat/src/server/shared/small-model/{get-small-model.ts, index.ts}` with the helper above. +2. Update `packages/chat/src/server/desktop/index.ts` barrel if needed; new helper lives in `shared/` and is imported directly from `@superset/chat/server/shared/small-model` — no re-export from desktop. +3. Rewrite `apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-branch-name.ts`: + - Replace `callSmallModel` + provider branching with `getSmallModel()` + `generateText({ model, system, prompt })`. + - Keep `BRANCH_NAME_INSTRUCTIONS`, `resolveConflict`, `sanitizeBranchNameWithMaxLength`. +4. Grep for `callSmallModel`, `SmallModelProvider`, `getDefaultSmallModelProviders`, `generateTitleFromMessageWithStreamingModel`: + - Rewrite each consumer to `getSmallModel` + `generateText` (or Mastra Agent if the caller wants tracing). +5. Delete: + - `apps/desktop/src/lib/ai/call-small-model.ts` + test. + - `packages/chat/src/server/desktop/small-model/small-model.ts` + test + `index.ts`. + - `generateTitleFromMessageWithStreamingModel` from `title-generation.ts`. +6. `apps/desktop/src/lib/ai/provider-diagnostics.ts` — grep for consumers; delete if only `call-small-model.ts` uses it. Otherwise leave. +7. Audit `auth/anthropic` and `auth/openai`: keep exports chat-service uses for OAuth login UI; delete any credential-resolution helpers used only for small-model. +8. Run `bun run typecheck` + focused tests (chat-service, ai-branch-name). Fix breaks. +9. Smoke: launch desktop, create v1 workspace with a prompt, verify AI branch naming still works (both API key and OAuth paths). + +### Risks (step 1) +- **mastracode init side effects**: `createMastraCode` with disabled MCP/hooks still initializes storage, auto-detects project, etc. Confirm startup stays under ~200ms and doesn't create unwanted files. If it tries to touch a DB/libsql, pass an explicit `storage` config. +- **Second init conflict**: chat service already calls `createMastraCode` for its runtime. Running a second one for small-model might duplicate auth-storage singletons or compete for files. Mitigation: verify `createMastraCode` is side-effect-safe when called twice; if not, share the existing chat runtime's resolver. +- **Credential regression**: `authStorage.has("anthropic")` must cover all the "credential present" cases our current `getAnthropicCredentialsFromAnySource` covers (env vars, stored API keys, OAuth). Audit before replacing. + +--- + +## Step 2 — Host-service procedure + +### Actionable tasks +1. Port `sanitizeBranchNameWithMaxLength` (`apps/desktop/src/shared/utils/branch.ts`) and `resolveBranchPrefix` (`apps/desktop/src/lib/trpc/routers/workspaces/utils/branch-prefix.ts`) into `packages/host-service/src/trpc/router/workspace-creation/utils/`. +2. Create `packages/host-service/src/trpc/router/workspace-creation/utils/ai-branch-name.ts` — same helper as desktop's rewritten v1, imports `getSmallModel` from `@superset/chat/server/shared/small-model`. +3. Add to `workspace-creation.ts`: + ```ts + generateBranchName: publicProcedure + .input(z.object({ projectId: z.string(), prompt: z.string() })) + .mutation(async ({ input }) => { + const trimmed = input.prompt.trim(); + if (!trimmed) return { branchName: null }; + const project = /* existing project lookup */; + const existingBranches = /* existing branch listing */; + const prefix = await resolveBranchPrefix(project, existingBranches); + const branchName = await generateBranchNameFromPrompt(trimmed, existingBranches, prefix); + return { branchName }; + }), + ``` +4. Delete `packages/host-service/src/providers/model-providers/LocalModelProvider/utils/resolveAnthropicCredential.ts` + `resolveOpenAICredential.ts` if unused after step (LocalModelProvider no longer needs them since auth flows through mastracode). +5. Run typecheck + host-service tests. + +--- + +## Step 3 — Wire v2 + +### Actionable tasks +1. Update `apps/desktop/src/renderer/routes/_authenticated/components/DashboardNewWorkspaceModal/components/DashboardNewWorkspaceForm/PromptGroup/hooks/useSubmitWorkspace/useSubmitWorkspace.ts`: + - Compute `willGenerateAIName = !draft.branchNameEdited && !!trimmedPrompt && !draft.linkedPR`. + - Fallback via `resolveNames(draft)` (unchanged). + - Insert pending row with status `"generating-branch"` if `willGenerateAIName`. + - Close + navigate (unchanged). + - If `willGenerateAIName`, race `client.workspaceCreation.generateBranchName.mutate(...)` vs 30s timeout: + - success → update pending row `branchName` + status `"creating"`. + - auth error → toast + abort + remove pending row. + - other/timeout → toast `"Using random branch name..."`, keep fallback name. + - Call `client.workspaceCreation.create(...)` with resolved `branchName`. +2. Add `"generating-branch"` to `pendingWorkspaces` status union (`packages/local-db/src/schema/schema.ts`). Drizzle migration. +3. Update pending page UI (`apps/desktop/src/renderer/routes/_authenticated/_dashboard/pending/$pendingId/page.tsx`) to render "Naming your branch…" for that status. + +--- + +## Effort + +| Step | Effort | +|---|---| +| 0. mastracode upgrade | ✅ done | +| 1. Shared helper + v1 migration + deletions | 2–3 hrs | +| 2. Host-service procedure | 1–1.5 hrs | +| 3. v2 wiring + pending UI | 1–2 hrs | +| **Remaining** | **~4–6.5 hrs** | + +## Risks + +- **mastracode init side effects** at singleton init (see step 1). +- **Remote host-service API-key availability**: remote hosts need `ANTHROPIC_API_KEY` / `OPENAI_API_KEY` set; otherwise v2 on remote hosts falls back to random-name. Document. +- **OAuth parity in host-service**: host-service can't do an interactive OAuth flow. `createAuthStorage().loadStoredApiKeysIntoEnv(...)` loads stored API keys but NOT OAuth tokens into env. For host-service, OAuth-only users get random names. +- **Diagnostics UI**: removing `provider-diagnostics.ts` removes mid-call `reportProviderIssue` signals. Audit settings UI for providers; they may source signals from chat-service regardless. + +## Out of scope +- Live/debounced ghost suggestion in v2 branch-name input. +- Retiring v1's desktop-tRPC `generateBranchName` procedure (it becomes a proxy over the shared helper; deleting it is a follow-up). diff --git a/apps/desktop/plans/done/20260417-fix-api-key-storage-slot.md b/apps/desktop/plans/done/20260417-fix-api-key-storage-slot.md new file mode 100644 index 00000000000..d31b76a60bb --- /dev/null +++ b/apps/desktop/plans/done/20260417-fix-api-key-storage-slot.md @@ -0,0 +1,57 @@ +# Fix: API keys overwritten by OAuth connect/disconnect cycle + +## Problem + +Settings > Models "API key" field writes to the same auth.json slot as OAuth. When a user: +1. Saves an API key → `authStorage.set("anthropic", { type: "api_key", key: "sk-..." })` +2. Connects OAuth → `authStorage.login("anthropic", ...)` overwrites with `{ type: "oauth", ... }` +3. Disconnects OAuth → `authStorage.remove("anthropic")` deletes everything + +The API key is lost. The model picker shows "disabled" even though the user saved a key. + +Chat still works because `createMastraCode`'s model resolution reads from env vars / external config independently of this status check. + +## Root cause + +`setApiKeyForProvider` uses `authStorage.set(providerId, credential)` which writes to the main provider slot. OAuth also writes to the same slot. They collide. + +mastracode's `AuthStorage` has **two separate storage mechanisms**: +- `set(providerId, credential)` / `get(providerId)` → main slot (`"anthropic"` in auth.json) +- `setStoredApiKey(providerId, key)` / `getStoredApiKey(providerId)` → dedicated API key slot (`"apikey:anthropic"` in auth.json) + +We're using the wrong one for API keys. + +## Fix + +### `auth-storage-utils.ts` + +**`setApiKeyForProvider`**: switch from `authStorage.set()` to `authStorage.setStoredApiKey()`. + +**`clearApiKeyForProvider`**: clear the `apikey:` slot. Use `authStorage.set("apikey:", ...)` with a removal, or check `hasStoredApiKey` and handle accordingly. Since mastracode doesn't expose `removeStoredApiKey`, use `authStorage.remove("apikey:")`. + +**`resolveAuthMethodForProvider`**: after checking the main slot, also check `authStorage.hasStoredApiKey(providerId)` as a fallback → return `"api_key"`. + +### `chat-service.ts` + +No changes needed — `getAnthropicAuthStatus` and `getOpenAIAuthStatus` already delegate to `resolveAuthMethodForProvider` which will now find stored API keys. + +The `setStoredAnthropicApiKeyFromEnvVariables` helper in `disconnectAnthropicOAuth` should also use `setStoredApiKey` for consistency, but it's less critical since it reads from the env config file. + +## Behavior after fix + +| Action | `"anthropic"` (main) | `"apikey:anthropic"` (dedicated) | +|---|---|---| +| Save API key (Settings) | unchanged | written | +| Connect OAuth | overwritten with OAuth | survives | +| Disconnect OAuth | removed | survives | +| Auth status check | reads both | ← | + +## Side effect: small-model tasks + +`getSmallModel` reads `apikey:anthropic` from auth.json directly. Currently, API keys saved via Settings go to the main `"anthropic"` slot, so `getSmallModel` doesn't find them. After this fix, saved API keys land in `apikey:anthropic` where `getSmallModel` already looks → branch naming works for Settings-saved keys without any additional change. + +## Scope + +- `packages/chat/src/server/desktop/chat-service/auth-storage-utils.ts` (~15 LOC changed) +- `packages/chat/src/server/desktop/chat-service/chat-service.ts` — `setStoredAnthropicApiKeyFromEnvVariables` updated for consistency (~2 LOC) +- Tests in `chat-service.test.ts` if any mock `setApiKeyForProvider` behavior diff --git a/apps/desktop/runtime-dependencies.ts b/apps/desktop/runtime-dependencies.ts index 18d5e5b2e4b..8b02f1f71ec 100644 --- a/apps/desktop/runtime-dependencies.ts +++ b/apps/desktop/runtime-dependencies.ts @@ -108,7 +108,7 @@ export const mainExternalizedDependencies = [ // mastracode transitively loads @mastra/fastembed → onnxruntime-node, whose // native binding is loaded via a dynamic `require` that @rollup/plugin-commonjs // can't resolve at bundle time. Externalizing lets Node handle the require at - // runtime from node_modules. + // runtime from node_modules. Also keeps the bundle size sane (~20 MB chunk). "mastracode", ]; diff --git a/apps/desktop/src/lib/ai/call-small-model.test.ts b/apps/desktop/src/lib/ai/call-small-model.test.ts deleted file mode 100644 index ff133f56a18..00000000000 --- a/apps/desktop/src/lib/ai/call-small-model.test.ts +++ /dev/null @@ -1,399 +0,0 @@ -import { beforeEach, describe, expect, it, mock } from "bun:test"; -import type { SmallModelProvider } from "@superset/chat/server/desktop"; - -const getDefaultSmallModelProvidersMock = mock((): SmallModelProvider[] => []); - -mock.module("@superset/chat/server/desktop", () => ({ - getDefaultSmallModelProviders: getDefaultSmallModelProvidersMock, - generateTitleFromMessage: mock(async () => null), - generateTitleFromMessageWithStreamingModel: mock(async () => null), -})); - -const { callSmallModel } = await import("./call-small-model"); - -describe("callSmallModel", () => { - beforeEach(() => { - getDefaultSmallModelProvidersMock.mockReset(); - getDefaultSmallModelProvidersMock.mockReturnValue([]); - }); - - it("skips unsupported credentials and falls through to the next working provider", async () => { - const { result, attempts } = await callSmallModel({ - providers: [ - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "oauth-token", - kind: "oauth", - source: "auth-storage", - }), - isSupported: () => ({ - supported: false, - reason: "unsupported oauth", - }), - createModel: () => "openai-model", - }, - { - id: "anthropic", - name: "Anthropic", - resolveCredentials: () => ({ - apiKey: "anthropic-token", - kind: "oauth", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => "anthropic-model", - }, - ], - invoke: async ({ providerId, model }) => - providerId === "anthropic" && model === "anthropic-model" - ? "generated title" - : null, - }); - - expect(result).toBe("generated title"); - expect(attempts).toEqual([ - { - providerId: "openai", - providerName: "OpenAI", - credentialKind: "oauth", - credentialSource: "auth-storage", - issue: { - code: "unsupported_credentials", - capability: "small_model_tasks", - remediation: "add_api_key", - message: "unsupported oauth", - }, - outcome: "unsupported-credentials", - reason: "unsupported oauth", - }, - { - providerId: "anthropic", - providerName: "Anthropic", - credentialKind: "oauth", - credentialSource: "auth-storage", - outcome: "succeeded", - }, - ]); - }); - - it("allows OpenAI OAuth credentials on the small-model path", async () => { - const { result, attempts } = await callSmallModel({ - providers: [ - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "oauth-token", - kind: "oauth", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => "openai-model", - }, - ], - invoke: async ({ providerId, model }) => - providerId === "openai" && model === "openai-model" - ? "generated title" - : null, - }); - - expect(result).toBe("generated title"); - expect(attempts).toEqual([ - { - providerId: "openai", - providerName: "OpenAI", - credentialKind: "oauth", - credentialSource: "auth-storage", - outcome: "succeeded", - }, - ]); - }); - - it("treats empty-string results as successful model output", async () => { - const { result, attempts } = await callSmallModel({ - providers: [ - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "oauth-token", - kind: "oauth", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => "openai-model", - }, - ], - invoke: async () => "", - }); - - expect(result).toBe(""); - expect(attempts).toEqual([ - { - providerId: "openai", - providerName: "OpenAI", - credentialKind: "oauth", - credentialSource: "auth-storage", - outcome: "succeeded", - }, - ]); - }); - - it("classifies missing OpenAI scopes as a canonical provider issue", async () => { - const { result, attempts } = await callSmallModel({ - providers: [ - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "oauth-token", - kind: "oauth", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => "openai-model", - }, - ], - invoke: async () => { - throw new Error( - "You have insufficient permissions for this operation. Missing scopes: api.responses.write.", - ); - }, - }); - - expect(result).toBeNull(); - expect(attempts).toEqual([ - { - providerId: "openai", - providerName: "OpenAI", - credentialKind: "oauth", - credentialSource: "auth-storage", - issue: { - code: "missing_scope", - capability: "small_model_tasks", - remediation: "check_permissions", - scope: "api.responses.write", - message: "OpenAI needs permission api.responses.write", - }, - outcome: "failed", - reason: - "You have insufficient permissions for this operation. Missing scopes: api.responses.write.", - }, - ]); - }); - - it("returns null after exhausting providers", async () => { - const { result, attempts } = await callSmallModel({ - providers: [ - { - id: "anthropic", - name: "Anthropic", - resolveCredentials: () => null, - isSupported: () => ({ supported: true }), - createModel: () => "unused", - }, - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "api-key", - kind: "apiKey", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => "openai-model", - }, - ], - invoke: async () => null, - }); - - expect(result).toBeNull(); - expect(attempts).toEqual([ - { - providerId: "anthropic", - providerName: "Anthropic", - outcome: "missing-credentials", - }, - { - providerId: "openai", - providerName: "OpenAI", - credentialKind: "apiKey", - credentialSource: "auth-storage", - outcome: "empty-result", - }, - ]); - }); - - it("skips expired oauth credentials before attempting the request", async () => { - const { result, attempts } = await callSmallModel({ - providers: [ - { - id: "anthropic", - name: "Anthropic", - resolveCredentials: () => ({ - apiKey: "expired-oauth", - kind: "oauth", - source: "config", - expiresAt: Date.now() - 1_000, - }), - isSupported: () => ({ supported: true }), - createModel: () => "anthropic-model", - }, - ], - invoke: async () => "should-not-run", - }); - - expect(result).toBeNull(); - expect(attempts).toEqual([ - { - providerId: "anthropic", - providerName: "Anthropic", - credentialKind: "oauth", - credentialSource: "config", - issue: { - code: "expired", - capability: "small_model_tasks", - remediation: "reconnect", - message: "Anthropic session expired", - }, - outcome: "expired-credentials", - reason: "Anthropic session expired", - }, - ]); - }); - - it("continues after a provider throws and returns the next successful result", async () => { - const { result, attempts } = await callSmallModel({ - providers: [ - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "api-key", - kind: "apiKey", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => { - throw new Error("provider unavailable"); - }, - }, - { - id: "anthropic", - name: "Anthropic", - resolveCredentials: () => ({ - apiKey: "anthropic-key", - kind: "apiKey", - source: "config", - }), - isSupported: () => ({ supported: true }), - createModel: () => "anthropic-model", - }, - ], - invoke: async ({ providerId, model }) => - providerId === "anthropic" && model === "anthropic-model" - ? "fallback title" - : null, - }); - - expect(result).toBe("fallback title"); - expect(attempts).toEqual([ - { - providerId: "openai", - providerName: "OpenAI", - credentialKind: "apiKey", - credentialSource: "auth-storage", - issue: { - code: "unknown_error", - capability: "small_model_tasks", - remediation: "try_again", - message: "OpenAI could not complete this request", - }, - outcome: "failed", - reason: "provider unavailable", - }, - { - providerId: "anthropic", - providerName: "Anthropic", - credentialKind: "apiKey", - credentialSource: "config", - outcome: "succeeded", - }, - ]); - }); - - it("respects providerOrder when a caller prefers one provider first", async () => { - const visited: string[] = []; - - await callSmallModel({ - providers: [ - { - id: "anthropic", - name: "Anthropic", - resolveCredentials: () => ({ - apiKey: "anthropic-key", - kind: "apiKey", - source: "config", - }), - isSupported: () => ({ supported: true }), - createModel: () => "anthropic-model", - }, - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "openai-key", - kind: "apiKey", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => "openai-model", - }, - ], - providerOrder: ["openai", "anthropic"], - invoke: async ({ providerId }) => { - visited.push(providerId); - return "title"; - }, - }); - - expect(visited).toEqual(["openai"]); - }); - - it("uses shared default providers when none are supplied", async () => { - getDefaultSmallModelProvidersMock.mockReturnValue([ - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => ({ - apiKey: "api-key", - kind: "apiKey", - source: "auth-storage", - }), - isSupported: () => ({ supported: true }), - createModel: () => "shared-openai-model", - }, - ]); - - const { result, attempts } = await callSmallModel({ - invoke: async ({ providerId, model }) => - providerId === "openai" && model === "shared-openai-model" - ? "title" - : null, - }); - - expect(result).toBe("title"); - expect(getDefaultSmallModelProvidersMock).toHaveBeenCalledTimes(1); - expect(attempts).toEqual([ - { - providerId: "openai", - providerName: "OpenAI", - credentialKind: "apiKey", - credentialSource: "auth-storage", - outcome: "succeeded", - }, - ]); - }); -}); diff --git a/apps/desktop/src/lib/ai/call-small-model.ts b/apps/desktop/src/lib/ai/call-small-model.ts index 429a7a0147e..369c548ee28 100644 --- a/apps/desktop/src/lib/ai/call-small-model.ts +++ b/apps/desktop/src/lib/ai/call-small-model.ts @@ -1,24 +1,33 @@ -import { - getDefaultSmallModelProviders, - type SmallModelCredential, - type SmallModelProvider, -} from "@superset/chat/server/desktop"; -import { - classifyProviderIssue, - type ProviderId, - type ProviderIssue, -} from "shared/ai/provider-status"; -import { - clearProviderIssue, - reportProviderIssue, -} from "./provider-diagnostics"; +// FORK NOTE: upstream #3517 removed fork's SmallModelProviders array +// and the provider-diagnostics store. Fork code (enhance-text.ts, +// git-operations.ts) still calls callSmallModel({ invoke }) expecting +// { result, attempts } with per-provider fallback. This shim restores +// that behavior on top of getSmallModelCandidates() (a fork-maintained +// replacement that returns the full priority list with OAuth / API key +// / proxy AUTH_TOKEN correctly wired via getAnthropicProviderOptions). +// +// Trade-offs vs. the pre-#3517 fork: +// - ProviderIssue reporting collapsed to generic `failed` — upstream +// removed the diagnostic classifiers when it dropped +// provider-diagnostics, and fork no longer surfaces them anywhere +// except describeEnhanceFailure's reason string. +// - Credential resolution happens synchronously (mastracode token +// refresh is not awaited in the candidate list). If an OAuth access +// token is actually expired, the next candidate in the priority +// chain is tried. +import { getSmallModelCandidates } from "@superset/chat/server/shared"; +import type { ProviderId, ProviderIssue } from "shared/ai/provider-status"; -type SmallModelProviderId = ProviderId; +export type SmallModelCredentialKind = "api_key" | "oauth" | "env"; +export interface SmallModelCredential { + kind: SmallModelCredentialKind; + source?: string; +} export interface SmallModelAttempt { - providerId: SmallModelProviderId; + providerId: ProviderId; providerName: string; - credentialKind?: SmallModelCredential["kind"]; + credentialKind?: SmallModelCredentialKind; credentialSource?: string; issue?: ProviderIssue; outcome: @@ -32,153 +41,123 @@ export interface SmallModelAttempt { } export interface SmallModelInvocationContext { - providerId: SmallModelProviderId; + providerId: ProviderId; providerName: string; model: unknown; credentials: SmallModelCredential; } -function orderProviders( - providers: SmallModelProvider[], - providerOrder?: SmallModelProviderId[], -): SmallModelProvider[] { - if (!providerOrder || providerOrder.length === 0) { - return providers; - } - - const rank = new Map( - providerOrder.map((providerId, index) => [providerId, index]), - ); - return [...providers].sort((left, right) => { - const leftRank = rank.get(left.id) ?? Number.MAX_SAFE_INTEGER; - const rightRank = rank.get(right.id) ?? Number.MAX_SAFE_INTEGER; - return leftRank - rightRank; - }); +function toShimCredentialKind( + kind: "apiKey" | "oauth", +): SmallModelCredentialKind { + return kind === "oauth" ? "oauth" : "api_key"; } export async function callSmallModel({ invoke, - providers = getDefaultSmallModelProviders(), providerOrder, }: { invoke: ( context: SmallModelInvocationContext, ) => Promise; - providers?: SmallModelProvider[]; - providerOrder?: SmallModelProviderId[]; + providerOrder?: ProviderId[]; }): Promise<{ result: TResult | null; attempts: SmallModelAttempt[]; }> { + const allCandidates = getSmallModelCandidates(); + + const ordered = providerOrder + ? [...allCandidates].sort((a, b) => { + const ai = providerOrder.indexOf(a.providerId); + const bi = providerOrder.indexOf(b.providerId); + return ( + (ai === -1 ? Number.MAX_SAFE_INTEGER : ai) - + (bi === -1 ? Number.MAX_SAFE_INTEGER : bi) + ); + }) + : allCandidates; + const attempts: SmallModelAttempt[] = []; - for (const provider of orderProviders(providers, providerOrder)) { - const credentials = provider.resolveCredentials(); - if (!credentials) { - attempts.push({ - providerId: provider.id, - providerName: provider.name, - outcome: "missing-credentials", - }); - clearProviderIssue(provider.id, "small_model_tasks"); - continue; - } - if ( - credentials.kind === "oauth" && - typeof credentials.expiresAt === "number" && - credentials.expiresAt <= Date.now() - ) { - const issue: ProviderIssue = { - code: "expired", - capability: "small_model_tasks", - remediation: "reconnect", - message: `${provider.name} session expired`, - }; - attempts.push({ - providerId: provider.id, - providerName: provider.name, - credentialKind: credentials.kind, - credentialSource: credentials.source, - issue, - outcome: "expired-credentials", - reason: issue.message, - }); - reportProviderIssue(provider.id, issue); - continue; - } + if (ordered.length === 0) { + // No credentials at all for either provider. Fabricate two + // missing-credentials attempts so describeEnhanceFailure's + // "every attempt is missing-credentials" branch triggers the + // correct "アカウントが接続されていません" message. + return { + result: null, + attempts: [ + { + providerId: "anthropic", + providerName: "Anthropic", + outcome: "missing-credentials", + }, + { + providerId: "openai", + providerName: "OpenAI", + outcome: "missing-credentials", + }, + ], + }; + } - const support = provider.isSupported(credentials); - if (!support.supported) { - const issue: ProviderIssue = { - code: "unsupported_credentials", - capability: "small_model_tasks", - remediation: "add_api_key", - message: - support.reason ?? - `${provider.name} credentials are not supported for this request`, - }; + for (const candidate of ordered) { + const credentials: SmallModelCredential = { + kind: toShimCredentialKind(candidate.credentialKind), + source: candidate.credentialSource, + }; + let model: unknown; + try { + model = candidate.createModel(); + } catch (error) { attempts.push({ - providerId: provider.id, - providerName: provider.name, + providerId: candidate.providerId, + providerName: candidate.providerName, credentialKind: credentials.kind, - credentialSource: credentials.source, - issue, - outcome: "unsupported-credentials", - reason: support.reason, + credentialSource: candidate.credentialSource, + outcome: "failed", + reason: error instanceof Error ? error.message : String(error), }); - reportProviderIssue(provider.id, issue); continue; } try { - const model = await provider.createModel(credentials); const result = await invoke({ - providerId: provider.id, - providerName: provider.name, + providerId: candidate.providerId, + providerName: candidate.providerName, model, credentials, }); - if (result != null) { + if (result === null || result === undefined) { attempts.push({ - providerId: provider.id, - providerName: provider.name, + providerId: candidate.providerId, + providerName: candidate.providerName, credentialKind: credentials.kind, - credentialSource: credentials.source, - outcome: "succeeded", + credentialSource: candidate.credentialSource, + outcome: "empty-result", }); - clearProviderIssue(provider.id, "small_model_tasks"); - return { result, attempts }; + continue; } - attempts.push({ - providerId: provider.id, - providerName: provider.name, + providerId: candidate.providerId, + providerName: candidate.providerName, credentialKind: credentials.kind, - credentialSource: credentials.source, - outcome: "empty-result", + credentialSource: candidate.credentialSource, + outcome: "succeeded", }); - clearProviderIssue(provider.id, "small_model_tasks"); + return { result, attempts }; } catch (error) { - const reason = error instanceof Error ? error.message : String(error); - const issue = classifyProviderIssue({ - providerId: provider.id, - errorMessage: reason, - }); attempts.push({ - providerId: provider.id, - providerName: provider.name, + providerId: candidate.providerId, + providerName: candidate.providerName, credentialKind: credentials.kind, - credentialSource: credentials.source, - issue, + credentialSource: candidate.credentialSource, outcome: "failed", - reason, + reason: error instanceof Error ? error.message : String(error), }); - reportProviderIssue(provider.id, issue); } } - return { - result: null, - attempts, - }; + return { result: null, attempts }; } diff --git a/apps/desktop/src/lib/ai/provider-diagnostics.ts b/apps/desktop/src/lib/ai/provider-diagnostics.ts deleted file mode 100644 index be8206edc0e..00000000000 --- a/apps/desktop/src/lib/ai/provider-diagnostics.ts +++ /dev/null @@ -1,89 +0,0 @@ -import type { - ProviderCapability, - ProviderDiagnostic, - ProviderId, - ProviderIssue, -} from "shared/ai/provider-status"; - -const DIAGNOSTIC_CAPABILITIES: ProviderCapability[] = [ - "chat", - "small_model_tasks", - "workspace_titles", -]; - -const diagnostics = new Map(); - -function getDiagnosticKey( - providerId: ProviderId, - capability: ProviderCapability, -): string { - return `${providerId}:${capability}`; -} - -function getEmptyDiagnostic(providerId: ProviderId): ProviderDiagnostic { - return { - providerId, - issue: null, - updatedAt: null, - }; -} - -export function getProviderDiagnostic( - providerId: ProviderId, - capability?: ProviderCapability, -): ProviderDiagnostic { - if (capability) { - return ( - diagnostics.get(getDiagnosticKey(providerId, capability)) ?? - getEmptyDiagnostic(providerId) - ); - } - - let latestDiagnostic: ProviderDiagnostic | null = null; - for (const supportedCapability of DIAGNOSTIC_CAPABILITIES) { - const diagnostic = diagnostics.get( - getDiagnosticKey(providerId, supportedCapability), - ); - if (!diagnostic) { - continue; - } - if ( - latestDiagnostic === null || - (diagnostic.updatedAt ?? 0) > (latestDiagnostic.updatedAt ?? 0) - ) { - latestDiagnostic = diagnostic; - } - } - - return latestDiagnostic ?? getEmptyDiagnostic(providerId); -} - -export function getProviderDiagnostics(): ProviderDiagnostic[] { - return [getProviderDiagnostic("anthropic"), getProviderDiagnostic("openai")]; -} - -export function reportProviderIssue( - providerId: ProviderId, - issue: ProviderIssue, -): void { - const capability = issue.capability ?? "chat"; - diagnostics.set(getDiagnosticKey(providerId, capability), { - providerId, - issue, - updatedAt: Date.now(), - }); -} - -export function clearProviderIssue( - providerId: ProviderId, - capability?: ProviderCapability, -): void { - if (capability) { - diagnostics.delete(getDiagnosticKey(providerId, capability)); - return; - } - - for (const supportedCapability of DIAGNOSTIC_CAPABILITIES) { - diagnostics.delete(getDiagnosticKey(providerId, supportedCapability)); - } -} diff --git a/apps/desktop/src/lib/trpc/routers/changes/git-operations.ts b/apps/desktop/src/lib/trpc/routers/changes/git-operations.ts index baa65db7f8e..55fc0432410 100644 --- a/apps/desktop/src/lib/trpc/routers/changes/git-operations.ts +++ b/apps/desktop/src/lib/trpc/routers/changes/git-operations.ts @@ -1,7 +1,4 @@ -import { - generateTitleFromMessage, - generateTitleFromMessageWithStreamingModel, -} from "@superset/chat/server/desktop"; +import { generateTitleFromMessage } from "@superset/chat/server/desktop"; import { TRPCError } from "@trpc/server"; import { callSmallModel } from "lib/ai/call-small-model"; import { z } from "zod"; @@ -660,20 +657,8 @@ export const createGitOperationsRouter = () => { : f.diff; const { result } = await callSmallModel({ - invoke: async ({ - model, - credentials, - providerId, - providerName, - }) => { - if (providerId === "openai" && credentials.kind === "oauth") { - return generateTitleFromMessageWithStreamingModel({ - message: `File: ${f.path}\n\n${truncatedDiff}`, - model: model as never, - instructions: PHASE1_INSTRUCTIONS, - }); - } - return generateTitleFromMessage({ + invoke: async ({ model, providerId, providerName }) => + generateTitleFromMessage({ message: `File: ${f.path}\n\n${truncatedDiff}`, agentModel: model, agentId: `commit-file-summary-${providerId}`, @@ -683,8 +668,7 @@ export const createGitOperationsRouter = () => { surface: "commit-file-summary", provider: providerName, }, - }); - }, + }), }); return `${f.path}: ${result ?? "変更あり"}`; @@ -708,16 +692,8 @@ export const createGitOperationsRouter = () => { "日本語で簡潔なconventional commitメッセージを生成してください。コミットメッセージの行のみを返してください。"; const { result, attempts } = await callSmallModel({ - invoke: async ({ model, credentials, providerId, providerName }) => { - if (providerId === "openai" && credentials.kind === "oauth") { - return generateTitleFromMessageWithStreamingModel({ - message: PHASE2_PROMPT, - model: model as never, - instructions: PHASE2_INSTRUCTIONS, - }); - } - - return generateTitleFromMessage({ + invoke: async ({ model, providerId, providerName }) => + generateTitleFromMessage({ message: PHASE2_PROMPT, agentModel: model, agentId: `commit-message-${providerId}`, @@ -727,8 +703,7 @@ export const createGitOperationsRouter = () => { surface: "commit-message-generation", provider: providerName, }, - }); - }, + }), }); if (!result) { diff --git a/apps/desktop/src/lib/trpc/routers/index.ts b/apps/desktop/src/lib/trpc/routers/index.ts index 810eedb5158..c561e873f8c 100644 --- a/apps/desktop/src/lib/trpc/routers/index.ts +++ b/apps/desktop/src/lib/trpc/routers/index.ts @@ -24,7 +24,6 @@ import { createGitHubMetricsRouter } from "./github-metrics"; import { createHostServiceCoordinatorRouter } from "./host-service-coordinator"; import { createLanguageServicesRouter } from "./language-services"; import { createMenuRouter } from "./menu"; -import { createModelProvidersRouter } from "./model-providers"; import { createNotificationsRouter } from "./notifications"; import { createPermissionsRouter } from "./permissions"; import { createPortsRouter } from "./ports"; @@ -56,7 +55,6 @@ export const createAppRouter = ( auth: createAuthRouter(), autoUpdate: createAutoUpdateRouter(), cache: createCacheRouter(), - modelProviders: createModelProvidersRouter(), window: createWindowRouter(getWindow), projects: createProjectsRouter(getWindow), workspaces: createWorkspacesRouter(), diff --git a/apps/desktop/src/lib/trpc/routers/model-providers/index.ts b/apps/desktop/src/lib/trpc/routers/model-providers/index.ts deleted file mode 100644 index 511b3a4ef5c..00000000000 --- a/apps/desktop/src/lib/trpc/routers/model-providers/index.ts +++ /dev/null @@ -1,51 +0,0 @@ -import { - clearProviderIssue, - getProviderDiagnostic, -} from "lib/ai/provider-diagnostics"; -import { - deriveModelProviderStatus, - type ProviderId, -} from "shared/ai/provider-status"; -import { z } from "zod"; -import { publicProcedure, router } from "../.."; -import { chatService } from "../chat-service"; - -const providerIdSchema = z.enum(["anthropic", "openai"]); - -async function getProviderStatuses() { - const [anthropicAuthStatus, openAIAuthStatus] = await Promise.all([ - chatService.getAnthropicAuthStatus(), - chatService.getOpenAIAuthStatus(), - ]); - - return [ - deriveModelProviderStatus({ - providerId: "anthropic", - authStatus: anthropicAuthStatus, - diagnostic: getProviderDiagnostic("anthropic"), - }), - deriveModelProviderStatus({ - providerId: "openai", - authStatus: openAIAuthStatus, - diagnostic: getProviderDiagnostic("openai"), - }), - ]; -} - -export const createModelProvidersRouter = () => { - return router({ - getStatuses: publicProcedure.query(async () => { - return getProviderStatuses(); - }), - clearIssue: publicProcedure - .input(z.object({ providerId: providerIdSchema })) - .mutation(({ input }: { input: { providerId: ProviderId } }) => { - clearProviderIssue(input.providerId); - return { success: true }; - }), - }); -}; - -export type ModelProvidersRouter = ReturnType< - typeof createModelProvidersRouter ->; diff --git a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-branch-name.ts b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-branch-name.ts index 6997101917e..7ad12334bab 100644 --- a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-branch-name.ts +++ b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-branch-name.ts @@ -1,18 +1,12 @@ -import { - generateTitleFromMessage, - generateTitleFromMessageWithStreamingModel, -} from "@superset/chat/server/desktop"; -import { callSmallModel } from "lib/ai/call-small-model"; +import { generateTitleFromMessage } from "@superset/chat/server/desktop"; +import { getSmallModel } from "@superset/chat/server/shared"; import { sanitizeBranchNameWithMaxLength } from "shared/utils/branch"; const BRANCH_NAME_INSTRUCTIONS = "Generate a concise git branch name (2-4 words, kebab-case, descriptive). Return ONLY the branch name, nothing else."; const MAX_CONFLICT_RESOLUTION_ATTEMPTS = 1000; -const INITIAL_CONFLICT_SUFFIX = 2; // Start at -2 since -1 is implicit (no suffix) +const INITIAL_CONFLICT_SUFFIX = 2; -/** - * Checks if a branch name conflicts with existing branches (case-insensitive) - */ function hasConflict( branchName: string, existingBranchesSet: Set, @@ -20,29 +14,21 @@ function hasConflict( return existingBranchesSet.has(branchName.toLowerCase()); } -/** - * Resolves branch name conflicts by appending a number (-2, -3, etc.) - * IMPORTANT: Checks conflicts with prefix applied to match server behavior - */ function resolveConflict( baseName: string, existingBranches: string[], branchPrefix: string | undefined, ): string { - // Apply prefix to match what the server will do const prefixedBase = branchPrefix ? `${branchPrefix}/${baseName}` : baseName; - - // Quick check without creating Set (covers 90% of cases where no conflict exists) const lowerPrefixedBase = prefixedBase.toLowerCase(); const hasInitialConflict = existingBranches.some( (b) => b.toLowerCase() === lowerPrefixedBase, ); if (!hasInitialConflict) { - return baseName; // Return unprefixed - server will apply prefix + return baseName; } - // Only create Set if we need to loop through conflicts const existingSet = new Set(existingBranches.map((b) => b.toLowerCase())); let counter = INITIAL_CONFLICT_SUFFIX; @@ -64,54 +50,34 @@ function resolveConflict( : candidate; } - return candidate; // Return unprefixed - server will apply prefix + return candidate; } -/** - * Generates an AI-powered branch name from a user prompt with automatic conflict resolution. - * - * @param prompt - User's workspace description - * @param existingBranches - List of existing branch names to check for conflicts - * @param branchPrefix - Optional prefix that will be applied by the server (e.g., "avi") - * @returns Generated branch name WITHOUT prefix (server will apply it) or null if generation fails - * @throws Error if conflict resolution exceeds max attempts - */ export async function generateBranchNameFromPrompt( prompt: string, existingBranches: string[], branchPrefix?: string, ): Promise { - const { result } = await callSmallModel({ - invoke: async ({ credentials, providerId, providerName, model }) => { - if (providerId === "openai" && credentials.kind === "oauth") { - return generateTitleFromMessageWithStreamingModel({ - message: prompt, - model: model as never, - instructions: BRANCH_NAME_INSTRUCTIONS, - }); - } - - return generateTitleFromMessage({ - message: prompt, - agentModel: model, - agentId: `branch-namer-${providerId}`, - agentName: "Branch Namer", - instructions: BRANCH_NAME_INSTRUCTIONS, - tracingContext: { - surface: "workspace-branch-name", - provider: providerName, - }, - }); - }, - }); + const model = getSmallModel(); + if (!model) return null; - if (result !== null && result !== undefined) { - const sanitized = sanitizeBranchNameWithMaxLength(result); - if (sanitized) { - // Resolve conflicts with prefix applied (matches server behavior) - return resolveConflict(sanitized, existingBranches, branchPrefix); - } + let generated: string | null; + try { + generated = await generateTitleFromMessage({ + message: prompt, + agentModel: model, + agentId: "branch-namer", + agentName: "Branch Namer", + instructions: BRANCH_NAME_INSTRUCTIONS, + tracingContext: { surface: "workspace-branch-name" }, + }); + } catch (error) { + console.warn("[generateBranchNameFromPrompt] generation failed:", error); + return null; } - return null; + if (!generated) return null; + const sanitized = sanitizeBranchNameWithMaxLength(generated); + if (!sanitized) return null; + return resolveConflict(sanitized, existingBranches, branchPrefix); } diff --git a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.test.ts b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.test.ts index 0636b859f23..5fa3df6dff0 100644 --- a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.test.ts +++ b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.test.ts @@ -1,17 +1,9 @@ import { afterAll, beforeEach, describe, expect, it, mock } from "bun:test"; -import type { SmallModelAttempt } from "lib/ai/call-small-model"; -const callSmallModelMock = mock((async () => ({ - result: null, - attempts: [], -})) as (...args: unknown[]) => Promise<{ - result: string | null; - attempts: SmallModelAttempt[]; -}>); -const generateTitleFromMessageMock = mock( - (async () => null) as (...args: unknown[]) => Promise, +const getSmallModelMock = mock( + (() => null) as (...args: unknown[]) => unknown | null, ); -const generateTitleFromMessageWithStreamingModelMock = mock( +const generateTitleFromMessageMock = mock( (async () => null) as (...args: unknown[]) => Promise, ); @@ -31,15 +23,12 @@ type SelectedWorkspace = } | null; -mock.module("lib/ai/call-small-model", () => ({ - callSmallModel: callSmallModelMock, +mock.module("@superset/chat/server/shared", () => ({ + getSmallModel: getSmallModelMock, })); mock.module("@superset/chat/server/desktop", () => ({ - __esModule: true, generateTitleFromMessage: generateTitleFromMessageMock, - generateTitleFromMessageWithStreamingModel: - generateTitleFromMessageWithStreamingModelMock, })); mock.module("drizzle-orm", () => ({ @@ -89,100 +78,32 @@ const { describe("generateWorkspaceNameFromPrompt", () => { beforeEach(() => { - callSmallModelMock.mockClear(); - callSmallModelMock.mockImplementation(async () => ({ - result: null, - attempts: [], - })); + getSmallModelMock.mockClear(); + getSmallModelMock.mockReturnValue(null); + generateTitleFromMessageMock.mockClear(); + generateTitleFromMessageMock.mockResolvedValue(null); selectGetMock.mockReset(); selectGetMock.mockReturnValue(null); updateRunMock.mockReset(); updateRunMock.mockReturnValue({ changes: 1 }); localDbMock.select.mockClear(); localDbMock.update.mockClear(); - generateTitleFromMessageMock.mockClear(); - generateTitleFromMessageWithStreamingModelMock.mockClear(); }); - it("falls back to a prompt-derived title when no providers are available", async () => { + it("falls back to a prompt-derived title when no model is available", async () => { await expect( generateWorkspaceNameFromPrompt(" debug prod rename failure "), ).resolves.toEqual({ name: "debug prod rename failure", usedPromptFallback: true, warning: - "No model account was connected, so a prompt-based title was used.", - }); - }); - - it("uses the last relevant provider issue in the fallback warning", async () => { - callSmallModelMock.mockImplementation(async () => ({ - result: null, - attempts: [ - { - providerId: "anthropic", - providerName: "Anthropic", - outcome: "failed", - issue: { - code: "unknown_error", - message: "Anthropic could not complete this request", - }, - }, - { - providerId: "openai", - providerName: "OpenAI", - outcome: "failed", - issue: { - code: "missing_scope", - message: "OpenAI needs permission model.request", - }, - }, - ], - })); - - await expect( - generateWorkspaceNameFromPrompt("rename this workspace from prompt"), - ).resolves.toEqual({ - name: "rename this workspace from prompt", - usedPromptFallback: true, - warning: - "OpenAI needs permission model.request, so a prompt-based title was used.", + "A prompt-based title was used because model naming was unavailable.", }); }); - it("uses streaming title generation for OpenAI OAuth naming", async () => { - generateTitleFromMessageWithStreamingModelMock.mockResolvedValue( - "Checking In", - ); - callSmallModelMock.mockImplementationOnce((async ({ - invoke, - }: { - invoke: (context: { - providerId: "openai"; - providerName: string; - model: { id: string }; - credentials: { - apiKey: string; - kind: "oauth"; - source: string; - }; - }) => Promise; - }) => ({ - result: await invoke({ - providerId: "openai", - providerName: "OpenAI", - model: { id: "openai-model" }, - credentials: { - apiKey: "oauth-token", - kind: "oauth", - source: "auth-storage", - }, - }), - attempts: [], - })) as (...args: unknown[]) => Promise<{ - result: string | null; - attempts: SmallModelAttempt[]; - }>); + it("returns the model-generated title when a model is available", async () => { + getSmallModelMock.mockReturnValueOnce({ id: "test-model" }); + generateTitleFromMessageMock.mockResolvedValueOnce("Checking In"); await expect( generateWorkspaceNameFromPrompt("hey boss how are you"), @@ -190,21 +111,19 @@ describe("generateWorkspaceNameFromPrompt", () => { name: "Checking In", usedPromptFallback: false, }); - expect(generateTitleFromMessageWithStreamingModelMock).toHaveBeenCalledWith( - { - message: "hey boss how are you", - model: { id: "openai-model" }, - instructions: "You generate concise workspace titles.", - }, - ); - expect(generateTitleFromMessageMock).not.toHaveBeenCalled(); + expect(generateTitleFromMessageMock).toHaveBeenCalledWith({ + message: "hey boss how are you", + agentModel: { id: "test-model" }, + agentId: "workspace-namer", + agentName: "Workspace Namer", + instructions: "You generate concise workspace titles.", + tracingContext: { surface: "workspace-auto-name" }, + }); }); it("preserves empty-string model results instead of forcing fallback", async () => { - callSmallModelMock.mockImplementationOnce(async () => ({ - result: "", - attempts: [], - })); + getSmallModelMock.mockReturnValueOnce({ id: "test-model" }); + generateTitleFromMessageMock.mockResolvedValueOnce(""); await expect( generateWorkspaceNameFromPrompt("name this workspace"), @@ -213,6 +132,20 @@ describe("generateWorkspaceNameFromPrompt", () => { usedPromptFallback: false, }); }); + + it("falls back when generation throws", async () => { + getSmallModelMock.mockReturnValueOnce({ id: "test-model" }); + generateTitleFromMessageMock.mockRejectedValueOnce(new Error("boom")); + + await expect( + generateWorkspaceNameFromPrompt("rename this workspace from prompt"), + ).resolves.toEqual({ + name: "rename this workspace from prompt", + usedPromptFallback: true, + warning: + "A prompt-based title was used because model naming was unavailable.", + }); + }); }); afterAll(() => { @@ -221,11 +154,10 @@ afterAll(() => { describe("attemptWorkspaceAutoRenameFromPrompt", () => { beforeEach(() => { - callSmallModelMock.mockClear(); - callSmallModelMock.mockImplementation(async () => ({ - result: null, - attempts: [], - })); + getSmallModelMock.mockClear(); + getSmallModelMock.mockReturnValue(null); + generateTitleFromMessageMock.mockClear(); + generateTitleFromMessageMock.mockResolvedValue(null); selectGetMock.mockReset(); selectGetMock.mockReturnValue(null); updateRunMock.mockReset(); @@ -252,7 +184,7 @@ describe("attemptWorkspaceAutoRenameFromPrompt", () => { status: "skipped", reason: "workspace-named", }); - expect(callSmallModelMock).not.toHaveBeenCalled(); + expect(getSmallModelMock).not.toHaveBeenCalled(); expect(localDbMock.update).not.toHaveBeenCalled(); }); @@ -264,10 +196,8 @@ describe("attemptWorkspaceAutoRenameFromPrompt", () => { isUnnamed: true, deletingAt: null, }); - callSmallModelMock.mockImplementationOnce(async () => ({ - result: "", - attempts: [], - })); + getSmallModelMock.mockReturnValueOnce({ id: "test-model" }); + generateTitleFromMessageMock.mockResolvedValueOnce(""); await expect( attemptWorkspaceAutoRenameFromPrompt({ diff --git a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts index 02849066beb..1bb06606be5 100644 --- a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts +++ b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts @@ -1,13 +1,7 @@ -import { - generateTitleFromMessage, - generateTitleFromMessageWithStreamingModel, -} from "@superset/chat/server/desktop"; +import { generateTitleFromMessage } from "@superset/chat/server/desktop"; +import { getSmallModel } from "@superset/chat/server/shared"; import { workspaces } from "@superset/local-db"; import { and, eq, isNull } from "drizzle-orm"; -import { - callSmallModel, - type SmallModelAttempt, -} from "lib/ai/call-small-model"; import { localDb } from "main/lib/local-db"; import { deriveWorkspaceTitleFromPrompt } from "shared/utils/workspace-naming"; import { getWorkspaceAutoRenameDecision } from "./workspace-auto-rename"; @@ -32,66 +26,39 @@ export type WorkspaceAutoRenameResult = warning?: string; }; +const FALLBACK_WARNING = + "A prompt-based title was used because model naming was unavailable."; + export async function generateWorkspaceNameFromPrompt(prompt: string): Promise<{ name: string | null; usedPromptFallback: boolean; warning?: string; }> { - const { result, attempts } = await callSmallModel({ - invoke: async ({ credentials, providerId, providerName, model }) => { - if (providerId === "openai" && credentials.kind === "oauth") { - return generateTitleFromMessageWithStreamingModel({ - message: prompt, - model: model as never, - instructions: "You generate concise workspace titles.", - }); - } - - return generateTitleFromMessage({ + const model = getSmallModel(); + if (model) { + try { + const generated = await generateTitleFromMessage({ message: prompt, agentModel: model, - agentId: `workspace-namer-${providerId}`, + agentId: "workspace-namer", agentName: "Workspace Namer", instructions: "You generate concise workspace titles.", - tracingContext: { - surface: "workspace-auto-name", - provider: providerName, - }, + tracingContext: { surface: "workspace-auto-name" }, }); - }, - }); - if (result !== null && result !== undefined) { - return { name: result, usedPromptFallback: false }; - } - - for (const attempt of attempts) { - if (attempt.outcome === "failed") { - console.error( - `[workspace-ai-name] ${attempt.providerName} title generation failed`, - { - issue: attempt.issue ?? null, - reason: attempt.reason ?? null, - }, - ); - continue; - } - if (attempt.outcome === "unsupported-credentials") { - console.info( - `[workspace-ai-name] Skipping ${attempt.providerName} for title generation`, - { - issue: attempt.issue ?? attempt.reason, - }, - ); + if (generated !== null && generated !== undefined) { + return { name: generated, usedPromptFallback: false }; + } + } catch (error) { + console.error("[workspace-ai-name] title generation failed", error); } } const fallbackTitle = deriveWorkspaceTitleFromPrompt(prompt); if (fallbackTitle) { - console.info("[workspace-ai-name] Falling back to prompt-derived title"); return { name: fallbackTitle, usedPromptFallback: true, - warning: buildWorkspaceAutoNameFallbackWarning(attempts), + warning: FALLBACK_WARNING, }; } @@ -203,33 +170,3 @@ export async function attemptWorkspaceAutoRenameFromPrompt({ : "workspace-name-changed", }; } - -function buildWorkspaceAutoNameFallbackWarning( - attempts: SmallModelAttempt[], -): string { - if (attempts.length === 0) { - return "No model account was connected, so a prompt-based title was used."; - } - - for (let index = attempts.length - 1; index >= 0; index -= 1) { - const attempt = attempts[index]; - if (attempt.outcome === "expired-credentials") { - return `${attempt.issue?.message ?? `${attempt.providerName} needs to be reconnected`}, so a prompt-based title was used.`; - } - if (attempt.outcome === "failed") { - return `${attempt.issue?.message ?? `${attempt.providerName} couldn't generate a title`}, so a prompt-based title was used.`; - } - if (attempt.outcome === "unsupported-credentials") { - return `${attempt.issue?.message ?? "No compatible model account was available"}, so a prompt-based title was used.`; - } - } - - const missingCredentials = attempts.every( - (attempt) => attempt.outcome === "missing-credentials", - ); - if (missingCredentials) { - return "No model account was connected, so a prompt-based title was used."; - } - - return "A prompt-based title was used because model naming was unavailable."; -} diff --git a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/AnthropicOAuthDialog/AnthropicOAuthDialog.tsx b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/AnthropicOAuthDialog/AnthropicOAuthDialog.tsx index 2e4c508f083..882e6a3d64f 100644 --- a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/AnthropicOAuthDialog/AnthropicOAuthDialog.tsx +++ b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/AnthropicOAuthDialog/AnthropicOAuthDialog.tsx @@ -1,157 +1,24 @@ -import { Button } from "@superset/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@superset/ui/dialog"; -import { InputGroup, InputGroupInput } from "@superset/ui/input-group"; -import { Label } from "@superset/ui/label"; - -interface AnthropicOAuthDialogProps { - open: boolean; - authUrl: string | null; - code: string; - errorMessage: string | null; - isPreparing: boolean; - isPending: boolean; - canDisconnect: boolean; - onOpenChange: (open: boolean) => void; - onCodeChange: (value: string) => void; - onOpenAuthUrl: () => void; - onCopyAuthUrl: () => void; - onDisconnect: () => void; - onRetry: () => void; - onSubmit: () => void; -} - -export function AnthropicOAuthDialog({ - open, - authUrl, - code, - errorMessage, - isPreparing, - isPending, - canDisconnect, - onOpenChange, - onCodeChange, - onOpenAuthUrl, - onCopyAuthUrl, - onDisconnect, - onRetry, - onSubmit, -}: AnthropicOAuthDialogProps) { - const hasAuthUrl = Boolean(authUrl); - const showCodeInput = hasAuthUrl || isPending; - const primaryLabel = isPending - ? "Connecting..." - : hasAuthUrl - ? "Continue" - : "Try again"; - +import { OAuthDialog, type OAuthDialogProps } from "../OAuthDialog"; + +const ANTHROPIC_PROVIDER: OAuthDialogProps["provider"] = { + title: "Connect Anthropic", + description: + "Approve access in your browser, then paste the callback URL or `code#state` here.", + codeLabel: "Authorization code", + codePlaceholder: "Paste callback URL or code#state", + codeHint: + "Anthropic usually returns a full callback URL. Pasting either format works.", + preparingLabel: "Preparing Anthropic browser login...", +}; + +type AnthropicOAuthDialogProps = Omit; + +export function AnthropicOAuthDialog(props: AnthropicOAuthDialogProps) { return ( - - - - Connect Anthropic - - Approve access in your browser, then paste the callback URL or - `code#state` here. - - - -
- {isPreparing ? ( -
- Preparing Anthropic browser login... -
- ) : null} - - {showCodeInput ? ( -
-
- - -
- -
- - - onCodeChange(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter" && code.trim()) { - onSubmit(); - } - }} - disabled={isPending} - className="h-11 font-mono" - autoFocus - /> - -

- Anthropic usually returns a full callback URL. Pasting either - format works. -

-
-
- ) : null} - - {errorMessage ? ( -

{errorMessage}

- ) : null} - -
- -
- - {canDisconnect ? ( - - ) : null} -
-
-
-
-
+ ); } diff --git a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OAuthDialog/OAuthDialog.tsx b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OAuthDialog/OAuthDialog.tsx new file mode 100644 index 00000000000..7820f431fba --- /dev/null +++ b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OAuthDialog/OAuthDialog.tsx @@ -0,0 +1,180 @@ +import { Button } from "@superset/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@superset/ui/dialog"; +import { InputGroup, InputGroupInput } from "@superset/ui/input-group"; +import { Label } from "@superset/ui/label"; +import { useCallback, useState } from "react"; + +export interface OAuthDialogProps { + provider: { + title: string; + description: string; + codeLabel: string; + codePlaceholder: string; + codeHint: string; + preparingLabel: string; + }; + open: boolean; + authUrl: string | null; + code: string; + errorMessage: string | null; + isPreparing?: boolean; + isPending: boolean; + canDisconnect: boolean; + requireCodeForSubmit?: boolean; + onOpenChange: (open: boolean) => void; + onCodeChange: (value: string) => void; + onOpenAuthUrl: () => void; + onCopyAuthUrl: () => void; + onDisconnect: () => void; + onRetry?: () => void; + onSubmit: () => void; +} + +export function OAuthDialog({ + provider, + open, + authUrl, + code, + errorMessage, + isPreparing, + isPending, + canDisconnect, + requireCodeForSubmit, + onOpenChange, + onCodeChange, + onOpenAuthUrl, + onCopyAuthUrl, + onDisconnect, + onRetry, + onSubmit, +}: OAuthDialogProps) { + const hasAuthUrl = Boolean(authUrl); + const showCodeInput = hasAuthUrl || isPending; + const canSubmit = + !isPreparing && + !isPending && + (!requireCodeForSubmit || code.trim().length > 0); + const [copied, setCopied] = useState(false); + const handleCopy = useCallback(() => { + onCopyAuthUrl(); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + }, [onCopyAuthUrl]); + + return ( + + + + {provider.title} + {provider.description} + + +
+ {isPreparing ? ( +
+ {provider.preparingLabel} +
+ ) : null} + + {showCodeInput ? ( +
+
+ + +
+ +
+ + + onCodeChange(event.target.value)} + onKeyDown={(event) => { + if ( + event.key === "Enter" && + !event.nativeEvent.isComposing && + canSubmit + ) { + onSubmit(); + } + }} + disabled={isPending} + className="h-11 font-mono text-sm" + autoFocus + /> + +

+ {provider.codeHint} +

+
+
+ ) : !isPreparing ? ( +
+ {provider.preparingLabel} +
+ ) : null} + + {errorMessage ? ( +

{errorMessage}

+ ) : null} + +
+ +
+ + {canDisconnect ? ( + + ) : null} +
+
+
+
+
+ ); +} diff --git a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OAuthDialog/index.ts b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OAuthDialog/index.ts new file mode 100644 index 00000000000..602028a06a5 --- /dev/null +++ b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OAuthDialog/index.ts @@ -0,0 +1 @@ +export { OAuthDialog, type OAuthDialogProps } from "./OAuthDialog"; diff --git a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OpenAIOAuthDialog/OpenAIOAuthDialog.tsx b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OpenAIOAuthDialog/OpenAIOAuthDialog.tsx index 2edd4a525b6..a356da92ecc 100644 --- a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OpenAIOAuthDialog/OpenAIOAuthDialog.tsx +++ b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/components/OpenAIOAuthDialog/OpenAIOAuthDialog.tsx @@ -1,152 +1,17 @@ -import { Button } from "@superset/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogHeader, - DialogTitle, -} from "@superset/ui/dialog"; -import { InputGroup, InputGroupInput } from "@superset/ui/input-group"; -import { Label } from "@superset/ui/label"; - -const OPENAI_OAUTH_CALLBACK_URL = "http://localhost:1455/auth/callback"; - -interface OpenAIOAuthDialogProps { - open: boolean; - authUrl: string | null; - code: string; - errorMessage: string | null; - isPending: boolean; - canDisconnect: boolean; - onOpenChange: (open: boolean) => void; - onCodeChange: (value: string) => void; - onOpenAuthUrl: () => void; - onCopyAuthUrl: () => void; - onDisconnect: () => void; - onSubmit: () => void; -} - -export function OpenAIOAuthDialog({ - open, - authUrl, - code, - errorMessage, - isPending, - canDisconnect, - onOpenChange, - onCodeChange, - onOpenAuthUrl, - onCopyAuthUrl, - onDisconnect, - onSubmit, -}: OpenAIOAuthDialogProps) { - const hasAuthUrl = Boolean(authUrl); - - return ( - - - - Connect OpenAI - - Approve access in your browser. If the callback does not finish, - paste the redirected callback URL below. - - - -
-
- Tip: OpenAI - OAuth usually completes automatically after browser approval. If you - land on {`${OPENAI_OAUTH_CALLBACK_URL}?...`}, copy that - full URL and paste it below. -
- -
- - -
- - {hasAuthUrl ? ( -
-

OAuth URL

-

- {authUrl} -

-
- ) : ( -
- OAuth URL not ready yet. -
- )} - -
- - - onCodeChange(event.target.value)} - onKeyDown={(event) => { - if (event.key === "Enter" && !event.nativeEvent.isComposing) { - onSubmit(); - } - }} - disabled={isPending} - className="h-11 font-mono text-xs sm:text-sm" - autoFocus - /> - -

- Leave this empty if browser login finishes on its own. -

-
- - {errorMessage ? ( -

{errorMessage}

- ) : null} - -
- -
- - {canDisconnect ? ( - - ) : null} -
-
-
-
-
- ); +import { OAuthDialog, type OAuthDialogProps } from "../OAuthDialog"; + +const OPENAI_PROVIDER: OAuthDialogProps["provider"] = { + title: "Connect OpenAI", + description: + "Approve access in your browser. If the callback does not finish, paste the redirected callback URL below.", + codeLabel: "Callback URL (optional)", + codePlaceholder: "Paste callback URL", + codeHint: "Leave this empty if browser login finishes on its own.", + preparingLabel: "Preparing OpenAI browser login...", +}; + +type OpenAIOAuthDialogProps = Omit; + +export function OpenAIOAuthDialog(props: OpenAIOAuthDialogProps) { + return ; } diff --git a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useAnthropicOAuth/useAnthropicOAuth.ts b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useAnthropicOAuth/useAnthropicOAuth.ts index 967c84413af..313dc58829e 100644 --- a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useAnthropicOAuth/useAnthropicOAuth.ts +++ b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useAnthropicOAuth/useAnthropicOAuth.ts @@ -1,7 +1,6 @@ import { chatServiceTrpc } from "@superset/chat/client"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCopyToClipboard } from "renderer/hooks/useCopyToClipboard"; -import { electronTrpc } from "renderer/lib/electron-trpc"; import { electronTrpcClient } from "renderer/lib/trpc-client"; function getErrorMessage(error: unknown, fallback: string): string { @@ -75,7 +74,6 @@ export function useAnthropicOAuth({ const autoSubmitTimeoutRef = useRef | null>( null, ); - const electronUtils = electronTrpc.useUtils(); const { data: anthropicStatus, refetch: refetchAnthropicStatus } = chatServiceTrpc.auth.getAnthropicStatus.useQuery(); @@ -177,10 +175,6 @@ export function useAnthropicOAuth({ onModelSelectorOpenChange(true); try { - await electronTrpcClient.modelProviders.clearIssue.mutate({ - providerId: "anthropic", - }); - await electronUtils.modelProviders.getStatuses.invalidate(); await refetchAnthropicStatus(); await onAuthStateChange?.(); } catch (error) { @@ -193,7 +187,6 @@ export function useAnthropicOAuth({ [ clearAutoSubmitTimeout, completeAnthropicOAuthMutation, - electronUtils.modelProviders.getStatuses.invalidate, onAuthStateChange, onModelSelectorOpenChange, refetchAnthropicStatus, @@ -223,10 +216,6 @@ export function useAnthropicOAuth({ onModelSelectorOpenChange(true); try { - await electronTrpcClient.modelProviders.clearIssue.mutate({ - providerId: "anthropic", - }); - await electronUtils.modelProviders.getStatuses.invalidate(); await refetchAnthropicStatus(); await onAuthStateChange?.(); } catch (error) { @@ -237,7 +226,6 @@ export function useAnthropicOAuth({ } }, [ disconnectAnthropicOAuthMutation, - electronUtils.modelProviders.getStatuses.invalidate, onAuthStateChange, onModelSelectorOpenChange, refetchAnthropicStatus, diff --git a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useOpenAIOAuth/useOpenAIOAuth.ts b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useOpenAIOAuth/useOpenAIOAuth.ts index 59a74b5c903..2506adbbc6b 100644 --- a/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useOpenAIOAuth/useOpenAIOAuth.ts +++ b/apps/desktop/src/renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useOpenAIOAuth/useOpenAIOAuth.ts @@ -1,7 +1,6 @@ import { chatServiceTrpc } from "@superset/chat/client"; import { useCallback, useEffect, useMemo, useState } from "react"; import { useCopyToClipboard } from "renderer/hooks/useCopyToClipboard"; -import { electronTrpc } from "renderer/lib/electron-trpc"; import { electronTrpcClient } from "renderer/lib/trpc-client"; function getErrorMessage(error: unknown, fallback: string): string { @@ -47,7 +46,6 @@ export function useOpenAIOAuth({ const [oauthCode, setOauthCode] = useState(""); const [oauthError, setOauthError] = useState(null); const [hasPendingOAuthSession, setHasPendingOAuthSession] = useState(false); - const electronUtils = electronTrpc.useUtils(); const { data: openAIStatus, refetch: refetchOpenAIStatus } = chatServiceTrpc.auth.getOpenAIStatus.useQuery(); @@ -92,13 +90,14 @@ export function useOpenAIOAuth({ setOauthCode(""); setHasPendingOAuthSession(true); setOauthDialogOpen(true); + await openExternalUrl(result.url); } catch (error) { setOauthDialogOpen(true); setOauthError( getErrorMessage(error, "Failed to start OpenAI OAuth flow"), ); } - }, [startOpenAIOAuthMutation]); + }, [openExternalUrl, startOpenAIOAuthMutation]); const { copyToClipboard } = useCopyToClipboard(); const copyOAuthUrl = useCallback(() => { @@ -110,10 +109,6 @@ export function useOpenAIOAuth({ const syncOpenAIAuthUi = useCallback( async (action: "complete" | "disconnect") => { try { - await electronTrpcClient.modelProviders.clearIssue.mutate({ - providerId: "openai", - }); - await electronUtils.modelProviders.getStatuses.invalidate(); await refetchOpenAIStatus(); } catch (error) { console.error( @@ -122,7 +117,7 @@ export function useOpenAIOAuth({ ); } }, - [electronUtils.modelProviders.getStatuses.invalidate, refetchOpenAIStatus], + [refetchOpenAIStatus], ); const completeOpenAIOAuth = useCallback(async () => { diff --git a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/ModelsSettings.tsx b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/ModelsSettings.tsx index eb0c5b319a0..df9cb02305a 100644 --- a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/ModelsSettings.tsx +++ b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/ModelsSettings.tsx @@ -1,10 +1,12 @@ import { chatServiceTrpc } from "@superset/chat/client"; +import { Badge } from "@superset/ui/badge"; import { Button } from "@superset/ui/button"; import { Collapsible, CollapsibleContent, CollapsibleTrigger, } from "@superset/ui/collapsible"; +import { claudeIcon } from "@superset/ui/icons/preset-icons"; import { Input } from "@superset/ui/input"; import { toast } from "@superset/ui/sonner"; import { Switch } from "@superset/ui/switch"; @@ -15,19 +17,17 @@ import { AnthropicOAuthDialog } from "renderer/components/Chat/ChatInterface/com import { OpenAIOAuthDialog } from "renderer/components/Chat/ChatInterface/components/ModelPicker/components/OpenAIOAuthDialog"; import { useAnthropicOAuth } from "renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useAnthropicOAuth"; import { useOpenAIOAuth } from "renderer/components/Chat/ChatInterface/components/ModelPicker/hooks/useOpenAIOAuth"; -import { electronTrpc } from "renderer/lib/electron-trpc"; import { isItemVisible, SETTING_ITEM_ID, type SettingItemId, } from "../../../utils/settings-search"; -import { AccountCard } from "./components/AccountCard"; import { ConfigRow } from "./components/ConfigRow"; import { SettingsSection } from "./components/SettingsSection"; import { buildAnthropicEnvText, EMPTY_ANTHROPIC_FORM, - getProviderSubtitle, + getProviderAction, getStatusBadge, parseAnthropicForm, resolveProviderStatus, @@ -52,7 +52,6 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { SETTING_ITEM_ID.MODELS_NEXT_EDIT, visibleItems, ); - const [apiKeysOpen, setApiKeysOpen] = useState(true); const [overrideOpen, setOverrideOpen] = useState(true); const [nextEditAdvancedOpen, setNextEditAdvancedOpen] = useState(true); const [openAIApiKeyInput, setOpenAIApiKeyInput] = useState(""); @@ -68,14 +67,6 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { stopText: "", }); - const { data: providerStatuses, refetch: refetchProviderStatuses } = - electronTrpc.modelProviders.getStatuses.useQuery(); - const anthropicDiagnosticStatus = providerStatuses?.find( - (status) => status.providerId === "anthropic", - ); - const openAIDiagnosticStatus = providerStatuses?.find( - (status) => status.providerId === "openai", - ); const { data: anthropicAuthStatus, refetch: refetchAnthropicAuthStatus } = chatServiceTrpc.auth.getAnthropicStatus.useQuery(); const { data: openAIAuthStatus, refetch: refetchOpenAIAuthStatus } = @@ -106,8 +97,6 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { chatServiceTrpc.auth.clearInceptionApiKey.useMutation(); const setNextEditConfigMutation = chatServiceTrpc.nextEdit.setConfig.useMutation(); - const clearProviderIssueMutation = - electronTrpc.modelProviders.clearIssue.useMutation(); const { isStartingOAuth: isStartingAnthropicOAuth, @@ -116,10 +105,7 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { } = useAnthropicOAuth({ ...DIALOG_CONTEXT, onAuthStateChange: async () => { - await Promise.all([ - refetchAnthropicAuthStatus(), - refetchProviderStatuses(), - ]); + await refetchAnthropicAuthStatus(); }, }); const { @@ -167,9 +153,8 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { resolveProviderStatus({ providerId: "anthropic", authStatus: anthropicAuthStatus, - diagnosticStatus: anthropicDiagnosticStatus, }), - [anthropicAuthStatus, anthropicDiagnosticStatus], + [anthropicAuthStatus], ); const openAIStatus = useMemo( @@ -177,19 +162,10 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { resolveProviderStatus({ providerId: "openai", authStatus: openAIAuthStatus, - diagnosticStatus: openAIDiagnosticStatus, }), - [openAIAuthStatus, openAIDiagnosticStatus], + [openAIAuthStatus], ); - const anthropicSubtitle = useMemo( - () => getProviderSubtitle("anthropic", anthropicStatus), - [anthropicStatus], - ); - const openAISubtitle = useMemo( - () => getProviderSubtitle("openai", openAIStatus), - [openAIStatus], - ); const anthropicBadge = useMemo( () => getStatusBadge(anthropicStatus), [anthropicStatus], @@ -199,9 +175,6 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { [openAIStatus], ); - const clearProviderIssue = (providerId: "anthropic" | "openai") => - clearProviderIssueMutation.mutateAsync({ providerId }); - const formatTokenCount = (value: number) => { return new Intl.NumberFormat("en-US").format(value); }; @@ -226,8 +199,6 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { await Promise.all([ refetchAnthropicEnvConfig(), refetchAnthropicAuthStatus(), - clearProviderIssue("anthropic"), - refetchProviderStatuses(), ]); toast.success("Anthropic settings updated"); return true; @@ -243,11 +214,7 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { try { await setAnthropicApiKeyMutation.mutateAsync({ apiKey }); setAnthropicApiKeyInput(""); - await Promise.all([ - refetchAnthropicAuthStatus(), - clearProviderIssue("anthropic"), - refetchProviderStatuses(), - ]); + await refetchAnthropicAuthStatus(); toast.success("Anthropic API key updated"); } catch (error) { toast.error(error instanceof Error ? error.message : "Failed to save"); @@ -260,11 +227,7 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { try { await setOpenAIApiKeyMutation.mutateAsync({ apiKey }); setOpenAIApiKeyInput(""); - await Promise.all([ - refetchOpenAIAuthStatus(), - clearProviderIssue("openai"), - refetchProviderStatuses(), - ]); + await refetchOpenAIAuthStatus(); toast.success("OpenAI API key updated"); } catch (error) { toast.error(error instanceof Error ? error.message : "Failed to save"); @@ -347,63 +310,29 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { status, startOAuth, isStartingOAuth, - canDisconnect, onDisconnect, }: { status: typeof anthropicStatus | typeof openAIStatus; startOAuth: () => Promise; isStartingOAuth: boolean; - canDisconnect: boolean; onDisconnect: () => void; }) => { - if (!status || status.connectionState === "disconnected") { - return ( - - ); - } - - if (status.issue?.remediation === "reconnect") { - return ( - - ); - } - - if (canDisconnect) { + const action = getProviderAction(status); + if (!action) return null; + if (action.kind === "logout") { return ( ); } - return ( ); }; @@ -420,167 +349,165 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) {
{showAnthropic ? ( - - + } + > +
+
+
+

OAuth

+ {anthropicBadge ? ( + + {anthropicBadge.label} + + ) : null} +
+ {renderProviderAction({ + status: anthropicStatus, + startOAuth: startAnthropicOAuth, + isStartingOAuth: isStartingAnthropicOAuth, + onDisconnect: async () => { + if (anthropicStatus?.authMethod === "oauth") { + anthropicOAuthDialog.onDisconnect(); + } else { + await clearAnthropicApiKeyMutation.mutateAsync(); + setAnthropicApiKeyInput(""); + } + await refetchAnthropicAuthStatus(); + }, + })} +
+ { + setAnthropicApiKeyInput(event.target.value); + }} + placeholder={ + anthropicStatus?.authMethod === "api_key" + ? "Saved Anthropic API key" + : "sk-ant-..." + } + className="font-mono" + disabled={isSavingAnthropicApiKey} + /> + } + onSave={() => { + void saveAnthropicApiKey(); + }} + onClear={() => { + const nextForm = { ...anthropicForm, apiKey: "" }; + void (async () => { + try { + await clearAnthropicApiKeyMutation.mutateAsync(); + setAnthropicApiKeyInput(""); + setAnthropicForm(nextForm); + await refetchAnthropicAuthStatus(); + toast.success("Anthropic API key cleared"); + } catch (error) { + toast.error( + error instanceof Error + ? error.message + : "Failed to clear", + ); + } + })(); + }} + showSave={anthropicApiKeyInput.trim().length > 0} + disableSave={isSavingAnthropicApiKey} + showClear={anthropicStatus?.authMethod === "api_key"} + disableClear={isSavingAnthropicApiKey} + /> +
) : null} {showOpenAI ? ( - - + + } + > +
+
+
+

OAuth

+ {openAIBadge ? ( + + {openAIBadge.label} + + ) : null} +
+ {renderProviderAction({ + status: openAIStatus, + startOAuth: startOpenAIOAuth, + isStartingOAuth: isStartingOpenAIOAuth, + onDisconnect: async () => { + if (openAIStatus?.authMethod === "oauth") { + openAIOAuthDialog.onDisconnect(); + } else { + await clearOpenAIApiKeyMutation.mutateAsync(); + setOpenAIApiKeyInput(""); + } + await refetchOpenAIAuthStatus(); + }, + })} +
+ { + setOpenAIApiKeyInput(event.target.value); + }} + placeholder={ + openAIStatus?.authMethod === "api_key" + ? "Saved OpenAI API key" + : "sk-..." + } + className="font-mono" + disabled={isSavingOpenAIConfig} + /> + } + onSave={() => { + void saveOpenAIApiKey(); + }} + onClear={() => { + void (async () => { + try { + await clearOpenAIApiKeyMutation.mutateAsync(); + setOpenAIApiKeyInput(""); + await refetchOpenAIAuthStatus(); + toast.success("OpenAI API key cleared"); + } catch (error) { + toast.error( + error instanceof Error + ? error.message + : "Failed to clear", + ); + } + })(); + }} + showSave={openAIApiKeyInput.trim().length > 0} + disableSave={isSavingOpenAIConfig} + showClear={openAIStatus?.authMethod === "api_key"} + disableClear={isSavingOpenAIConfig} + /> +
) : null} - -
- - - - - {showAnthropic ? ( - { - setAnthropicApiKeyInput(event.target.value); - }} - placeholder={ - anthropicStatus?.authMethod === "api_key" - ? "Saved Anthropic API key" - : "sk-ant-..." - } - className="font-mono" - disabled={isSavingAnthropicApiKey} - /> - } - onSave={() => { - void saveAnthropicApiKey(); - }} - onClear={() => { - const nextForm = { ...anthropicForm, apiKey: "" }; - void (async () => { - try { - await clearAnthropicApiKeyMutation.mutateAsync(); - setAnthropicApiKeyInput(""); - setAnthropicForm(nextForm); - await Promise.all([ - refetchAnthropicAuthStatus(), - clearProviderIssue("anthropic"), - refetchProviderStatuses(), - ]); - toast.success("Anthropic API key cleared"); - } catch (error) { - toast.error( - error instanceof Error - ? error.message - : "Failed to clear", - ); - } - })(); - }} - disableSave={ - isSavingAnthropicApiKey || - anthropicApiKeyInput.trim().length === 0 - } - disableClear={ - isSavingAnthropicApiKey || - anthropicStatus?.authMethod !== "api_key" - } - /> - ) : null} - {showOpenAI ? ( - { - setOpenAIApiKeyInput(event.target.value); - }} - placeholder={ - openAIStatus?.authMethod === "api_key" - ? "Saved OpenAI API key" - : "sk-..." - } - className="font-mono" - disabled={isSavingOpenAIConfig} - /> - } - onSave={() => { - void saveOpenAIApiKey(); - }} - onClear={() => { - void (async () => { - try { - await clearOpenAIApiKeyMutation.mutateAsync(); - setOpenAIApiKeyInput(""); - await Promise.all([ - refetchOpenAIAuthStatus(), - clearProviderIssue("openai"), - refetchProviderStatuses(), - ]); - toast.success("OpenAI API key cleared"); - } catch (error) { - toast.error( - error instanceof Error - ? error.message - : "Failed to clear", - ); - } - })(); - }} - disableSave={ - isSavingOpenAIConfig || - openAIApiKeyInput.trim().length === 0 - } - disableClear={ - isSavingOpenAIConfig || - openAIStatus?.authMethod !== "api_key" - } - /> - ) : null} - -
-
- {showAnthropic ? (
@@ -595,111 +522,113 @@ export function ModelsSettings({ visibleItems }: ModelsSettingsProps) { Override Provider - - { - setAnthropicForm((current) => ({ - ...current, - authToken: event.target.value, - })); - }} - placeholder="sk-ant-..." - className="font-mono" - disabled={isSavingAnthropicConfig} - /> - } - onSave={() => { - void saveAnthropicForm(); - }} - onClear={() => { - const nextForm = { ...anthropicForm, authToken: "" }; - setAnthropicForm(nextForm); - void saveAnthropicForm(nextForm); - }} - disableSave={isSavingAnthropicConfig} - disableClear={ - isSavingAnthropicConfig || - anthropicForm.authToken.length === 0 - } - /> - { - setAnthropicForm((current) => ({ - ...current, - baseUrl: event.target.value, - })); - }} - placeholder="https://api.anthropic.com" - className="font-mono" - disabled={isSavingAnthropicConfig} - /> - } - onSave={() => { - void saveAnthropicForm(); - }} - onClear={() => { - const nextForm = { ...anthropicForm, baseUrl: "" }; - setAnthropicForm(nextForm); - void saveAnthropicForm(nextForm); - }} - disableSave={isSavingAnthropicConfig} - disableClear={ - isSavingAnthropicConfig || - anthropicForm.baseUrl.length === 0 - } - /> - { - setAnthropicForm((current) => ({ - ...current, - extraEnv: event.target.value, - })); - }} - placeholder={ - "CLAUDE_CODE_USE_BEDROCK=1\nAWS_REGION=us-east-1" - } - className="min-h-24 font-mono text-xs" - disabled={isSavingAnthropicConfig} - /> - } - onSave={() => { - void saveAnthropicForm(); - }} - onClear={ - hasAnthropicConfig - ? () => { - const nextForm = { - ...anthropicForm, - extraEnv: "", - }; - setAnthropicForm(nextForm); - void saveAnthropicForm(nextForm); + +
+ { + setAnthropicForm((current) => ({ + ...current, + authToken: event.target.value, + })); + }} + placeholder="sk-ant-..." + className="font-mono" + disabled={isSavingAnthropicConfig} + /> + } + onSave={() => { + void saveAnthropicForm(); + }} + onClear={() => { + const nextForm = { ...anthropicForm, authToken: "" }; + setAnthropicForm(nextForm); + void saveAnthropicForm(nextForm); + }} + disableSave={isSavingAnthropicConfig} + disableClear={ + isSavingAnthropicConfig || + anthropicForm.authToken.length === 0 + } + /> + { + setAnthropicForm((current) => ({ + ...current, + baseUrl: event.target.value, + })); + }} + placeholder="https://api.anthropic.com" + className="font-mono" + disabled={isSavingAnthropicConfig} + /> + } + onSave={() => { + void saveAnthropicForm(); + }} + onClear={() => { + const nextForm = { ...anthropicForm, baseUrl: "" }; + setAnthropicForm(nextForm); + void saveAnthropicForm(nextForm); + }} + disableSave={isSavingAnthropicConfig} + disableClear={ + isSavingAnthropicConfig || + anthropicForm.baseUrl.length === 0 + } + /> + { + setAnthropicForm((current) => ({ + ...current, + extraEnv: event.target.value, + })); + }} + placeholder={ + "CLAUDE_CODE_USE_BEDROCK=1\nAWS_REGION=us-east-1" } - : undefined - } - clearLabel="Clear" - disableSave={isSavingAnthropicConfig} - disableClear={ - isSavingAnthropicConfig || - anthropicForm.extraEnv.length === 0 - } - /> + className="min-h-24 font-mono text-xs" + disabled={isSavingAnthropicConfig} + /> + } + onSave={() => { + void saveAnthropicForm(); + }} + onClear={ + hasAnthropicConfig + ? () => { + const nextForm = { + ...anthropicForm, + extraEnv: "", + }; + setAnthropicForm(nextForm); + void saveAnthropicForm(nextForm); + } + : undefined + } + clearLabel="Clear" + disableSave={isSavingAnthropicConfig} + disableClear={ + isSavingAnthropicConfig || + anthropicForm.extraEnv.length === 0 + } + /> +
diff --git a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/AccountCard/AccountCard.tsx b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/AccountCard/AccountCard.tsx deleted file mode 100644 index 98e49024db8..00000000000 --- a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/AccountCard/AccountCard.tsx +++ /dev/null @@ -1,41 +0,0 @@ -import { Badge } from "@superset/ui/badge"; -import { cn } from "@superset/ui/utils"; -import type { ReactNode } from "react"; - -interface AccountCardProps { - title: string; - subtitle: string; - badge?: string; - badgeVariant?: "secondary" | "outline" | "destructive"; - actions?: ReactNode; - muted?: boolean; -} - -export function AccountCard({ - title, - subtitle, - badge, - badgeVariant = "secondary", - actions, - muted = false, -}: AccountCardProps) { - return ( -
-
-
-

{title}

-

{subtitle}

-
-
- {badge ? {badge} : null} - {actions} -
-
-
- ); -} diff --git a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/AccountCard/index.ts b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/AccountCard/index.ts deleted file mode 100644 index 88333bd623e..00000000000 --- a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/AccountCard/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { AccountCard } from "./AccountCard"; diff --git a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/ConfigRow/ConfigRow.tsx b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/ConfigRow/ConfigRow.tsx index d0e0b24c61c..953ca4dd156 100644 --- a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/ConfigRow/ConfigRow.tsx +++ b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/ConfigRow/ConfigRow.tsx @@ -1,4 +1,5 @@ import { Button } from "@superset/ui/button"; +import { cn } from "@superset/ui/utils"; import type { ReactNode } from "react"; interface ConfigRowProps { @@ -9,8 +10,11 @@ interface ConfigRowProps { onClear?: () => void; saveLabel?: string; clearLabel?: string; + showSave?: boolean; + showClear?: boolean; disableSave?: boolean; disableClear?: boolean; + className?: string; } export function ConfigRow({ @@ -21,11 +25,14 @@ export function ConfigRow({ onClear, saveLabel = "Save", clearLabel = "Clear", + showSave = true, + showClear = true, disableSave, disableClear, + className, }: ConfigRowProps) { return ( -
+

{title}

@@ -36,7 +43,7 @@ export function ConfigRow({
{field}
- {onClear ? ( + {onClear && showClear ? ( diff --git a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/SettingsSection/SettingsSection.tsx b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/SettingsSection/SettingsSection.tsx index 508ec0eabd9..19d4c485c0e 100644 --- a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/SettingsSection/SettingsSection.tsx +++ b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/components/SettingsSection/SettingsSection.tsx @@ -2,6 +2,7 @@ import type { ReactNode } from "react"; interface SettingsSectionProps { title: string; + icon?: ReactNode; description?: string; action?: ReactNode; children: ReactNode; @@ -9,6 +10,7 @@ interface SettingsSectionProps { export function SettingsSection({ title, + icon, description, action, children, @@ -17,7 +19,10 @@ export function SettingsSection({
-

{title}

+

+ {icon} + {title} +

{description ? (

{description}

) : null} diff --git a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/utils.ts b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/utils.ts index b14c60c63d2..b77b5b0bf75 100644 --- a/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/utils.ts +++ b/apps/desktop/src/renderer/routes/_authenticated/settings/models/components/ModelsSettings/utils.ts @@ -84,7 +84,7 @@ export function getProviderSubtitle( return status.issue.message; } if (!status || status.connectionState === "disconnected") { - return "No account connected"; + return ""; } if (status.source === "external" && status.authMethod === "oauth") { return EXTERNAL_OAUTH_LABELS[providerId]; @@ -101,8 +101,8 @@ export function getProviderSubtitle( export function getStatusBadge( status: ModelProviderStatus | undefined, ): { label: string; variant: "secondary" | "outline" | "destructive" } | null { - if (!status) { - return null; + if (!status || status.connectionState === "disconnected") { + return { label: "Not connected", variant: "outline" }; } if (status.issue?.code === "expired") { return { label: "Expired", variant: "destructive" }; @@ -119,22 +119,32 @@ export function getStatusBadge( export function resolveProviderStatus(params: { providerId: ProviderId; authStatus?: AuthStatusLike; - diagnosticStatus?: ModelProviderStatus; }): ModelProviderStatus | undefined { - const { providerId, authStatus, diagnosticStatus } = params; - if (!authStatus) { - return diagnosticStatus; - } + const { providerId, authStatus } = params; + if (!authStatus) return undefined; + return deriveModelProviderStatus({ providerId, authStatus }); +} - return deriveModelProviderStatus({ - providerId, - authStatus, - diagnostic: { - providerId, - issue: authStatus.authenticated - ? (diagnosticStatus?.issue ?? null) - : null, - updatedAt: null, - }, - }); +export type ProviderAction = + | { kind: "connect" } + | { kind: "reconnect" } + | { kind: "logout" } + | null; + +/** + * Single source of truth for the provider action button. + */ +export function getProviderAction( + status: ModelProviderStatus | undefined, +): ProviderAction { + if (!status || status.connectionState === "disconnected") { + return { kind: "connect" }; + } + if (status.issue?.remediation === "reconnect") { + return { kind: "reconnect" }; + } + if (status.connectionState === "connected") { + return { kind: "logout" }; + } + return { kind: "connect" }; } diff --git a/apps/desktop/src/shared/ai/provider-status.test.ts b/apps/desktop/src/shared/ai/provider-status.test.ts index 4a2ed63d85e..4c0e1e99267 100644 --- a/apps/desktop/src/shared/ai/provider-status.test.ts +++ b/apps/desktop/src/shared/ai/provider-status.test.ts @@ -1,23 +1,8 @@ import { describe, expect, it } from "bun:test"; -import { - deriveModelProviderStatus, - type ProviderDiagnostic, -} from "./provider-status"; +import { deriveModelProviderStatus } from "./provider-status"; describe("deriveModelProviderStatus", () => { - it("keeps a connected provider connected when only capability diagnostics fail", () => { - const diagnostic: ProviderDiagnostic = { - providerId: "openai", - issue: { - code: "missing_scope", - capability: "small_model_tasks", - remediation: "check_permissions", - scope: "api.responses.write", - message: "OpenAI needs permission api.responses.write", - }, - updatedAt: Date.now(), - }; - + it("marks an authenticated provider without issues as connected", () => { const status = deriveModelProviderStatus({ providerId: "openai", authStatus: { @@ -26,14 +11,15 @@ describe("deriveModelProviderStatus", () => { source: "managed", issue: null, }, - diagnostic, }); expect(status.connectionState).toBe("connected"); - expect(status.issue?.code).toBe("missing_scope"); - expect(status.capabilities.canUseChat).toBe(true); - expect(status.capabilities.canGenerateWorkspaceTitle).toBe(false); - expect(status.capabilities.canUseSmallModelTasks).toBe(false); + expect(status.issue).toBeNull(); + expect(status.capabilities).toEqual({ + canUseChat: true, + canGenerateWorkspaceTitle: true, + canUseSmallModelTasks: true, + }); }); it("treats expired auth as needs attention and disables all capabilities", () => { @@ -55,4 +41,20 @@ describe("deriveModelProviderStatus", () => { canUseSmallModelTasks: false, }); }); + + it("reports disconnected for providers with no source and no auth", () => { + const status = deriveModelProviderStatus({ + providerId: "openai", + authStatus: { + authenticated: false, + method: null, + source: null, + issue: null, + }, + }); + + expect(status.connectionState).toBe("disconnected"); + expect(status.issue).toBeNull(); + expect(status.capabilities.canUseChat).toBe(false); + }); }); diff --git a/apps/desktop/src/shared/ai/provider-status.ts b/apps/desktop/src/shared/ai/provider-status.ts index c74e27088af..910ebc2f659 100644 --- a/apps/desktop/src/shared/ai/provider-status.ts +++ b/apps/desktop/src/shared/ai/provider-status.ts @@ -5,40 +5,14 @@ export type ProviderConnectionState = | "disconnected" | "needs_attention"; -export type ProviderCapability = - | "chat" - | "workspace_titles" - | "small_model_tasks"; +export type ProviderRemediation = "reconnect" | "add_api_key"; -export type ProviderRemediation = - | "reconnect" - | "check_permissions" - | "check_billing" - | "add_api_key" - | "try_again"; - -export type ProviderIssueCode = - | "expired" - | "missing_scope" - | "forbidden" - | "quota_exceeded" - | "network_error" - | "unsupported_credentials" - | "empty_result" - | "unknown_error"; +export type ProviderIssueCode = "expired"; export interface ProviderIssue { code: ProviderIssueCode; message: string; - capability?: ProviderCapability; remediation?: ProviderRemediation; - scope?: string | null; -} - -export interface ProviderDiagnostic { - providerId: ProviderId; - issue: ProviderIssue | null; - updatedAt: number | null; } export interface AuthStatusLike { @@ -69,72 +43,6 @@ export function getProviderName(providerId: ProviderId): string { return providerId === "anthropic" ? "Anthropic" : "OpenAI"; } -export function classifyProviderIssue(params: { - providerId: ProviderId; - errorMessage: string; -}): ProviderIssue { - const { providerId, errorMessage } = params; - const normalized = errorMessage.trim(); - const lower = normalized.toLowerCase(); - - const missingScopeMatch = normalized.match( - /Missing scopes:\s*([A-Za-z0-9._,\s-]+)/i, - ); - if (missingScopeMatch || lower.includes("insufficient permissions")) { - const scope = - missingScopeMatch?.[1]?.trim().replace(/[.,;:]+$/, "") ?? null; - const providerName = getProviderName(providerId); - return { - code: "missing_scope", - capability: "small_model_tasks", - remediation: "check_permissions", - scope, - message: scope - ? `${providerName} needs permission ${scope}` - : `${providerName} is missing permission for this action`, - }; - } - - if (lower.includes("quota") || lower.includes("insufficient_quota")) { - return { - code: "quota_exceeded", - capability: "small_model_tasks", - remediation: "check_billing", - message: `${getProviderName(providerId)} quota or billing needs attention`, - }; - } - - if (lower.includes("forbidden") || lower.includes("status: 403")) { - return { - code: "forbidden", - capability: "small_model_tasks", - remediation: "check_permissions", - message: `${getProviderName(providerId)} denied this request`, - }; - } - - if ( - lower.includes("timed out") || - lower.includes("network") || - lower.includes("econn") || - lower.includes("fetch failed") - ) { - return { - code: "network_error", - capability: "small_model_tasks", - remediation: "try_again", - message: `${getProviderName(providerId)} request failed due to a network error`, - }; - } - - return { - code: "unknown_error", - capability: "small_model_tasks", - remediation: "try_again", - message: `${getProviderName(providerId)} could not complete this request`, - }; -} - function getIssueFromAuthStatus( providerId: ProviderId, authStatus: AuthStatusLike, @@ -142,7 +50,6 @@ function getIssueFromAuthStatus( if (authStatus.issue === "expired") { return { code: "expired", - capability: "chat", remediation: "reconnect", message: `${getProviderName(providerId)} session expired`, }; @@ -154,53 +61,24 @@ function getIssueFromAuthStatus( export function deriveModelProviderStatus(params: { providerId: ProviderId; authStatus: AuthStatusLike; - diagnostic?: ProviderDiagnostic | null; }): ModelProviderStatus { - const { providerId, authStatus, diagnostic } = params; - const authIssue = getIssueFromAuthStatus(providerId, authStatus); - const issue = authIssue ?? diagnostic?.issue ?? null; + const { providerId, authStatus } = params; + const issue = getIssueFromAuthStatus(providerId, authStatus); let connectionState: ProviderConnectionState = "disconnected"; if (authStatus.authenticated) { - connectionState = authIssue ? "needs_attention" : "connected"; - } else if (authIssue || authStatus.source !== null) { + connectionState = issue ? "needs_attention" : "connected"; + } else if (issue || authStatus.source !== null) { connectionState = "needs_attention"; } + const canUse = authStatus.authenticated && !issue; const capabilities: ProviderCapabilities = { - canUseChat: authStatus.authenticated, - canGenerateWorkspaceTitle: authStatus.authenticated, - canUseSmallModelTasks: authStatus.authenticated, + canUseChat: canUse, + canGenerateWorkspaceTitle: canUse, + canUseSmallModelTasks: canUse, }; - if (issue) { - switch (issue.code) { - case "expired": - capabilities.canUseChat = false; - capabilities.canGenerateWorkspaceTitle = false; - capabilities.canUseSmallModelTasks = false; - break; - case "missing_scope": - case "forbidden": - case "quota_exceeded": - case "network_error": - case "unsupported_credentials": - case "empty_result": - case "unknown_error": - if (issue.capability === "chat") { - capabilities.canUseChat = false; - } - if ( - issue.capability === "small_model_tasks" || - issue.capability === "workspace_titles" - ) { - capabilities.canGenerateWorkspaceTitle = false; - capabilities.canUseSmallModelTasks = false; - } - break; - } - } - return { providerId, connectionState, diff --git a/packages/chat/package.json b/packages/chat/package.json index 91806e673ef..610bf047928 100644 --- a/packages/chat/package.json +++ b/packages/chat/package.json @@ -23,6 +23,10 @@ "./server/hono": { "types": "./src/server/hono/index.ts", "default": "./src/server/hono/index.ts" + }, + "./server/shared": { + "types": "./src/server/shared/index.ts", + "default": "./src/server/shared/index.ts" } }, "scripts": { @@ -32,7 +36,7 @@ "dependencies": { "@ai-sdk/anthropic": "^3.0.43", "@ai-sdk/openai": "3.0.36", - "@mastra/core": "1.16.0", + "@mastra/core": "1.25.0", "@mastra/mcp": "1.3.1", "@superset/trpc": "workspace:*", "@superset/workspace-fs": "workspace:*", @@ -40,7 +44,7 @@ "@trpc/server": "^11.7.1", "ai": "^6.0.0", "hono": "^4.8.5", - "mastracode": "0.9.2", + "mastracode": "0.14.0", "superjson": "^2.2.5", "zod": "^4.3.5" }, diff --git a/packages/chat/src/server/desktop/auth/anthropic/anthropic.ts b/packages/chat/src/server/desktop/auth/anthropic/anthropic.ts index b28ef85d5cc..75c5fff8209 100644 --- a/packages/chat/src/server/desktop/auth/anthropic/anthropic.ts +++ b/packages/chat/src/server/desktop/auth/anthropic/anthropic.ts @@ -168,7 +168,7 @@ export function getCredentialsFromKeychain(): ClaudeCredentials | null { return null; } -export function getCredentialsFromAuthStorage(): ClaudeCredentials | null { +export async function getCredentialsFromAuthStorage(): Promise { try { const authStorage = createAuthStorage(); authStorage.reload(); @@ -187,18 +187,22 @@ export function getCredentialsFromAuthStorage(): ClaudeCredentials | null { }; } - if ( - credential.type === "oauth" && - typeof credential.access === "string" && - credential.access.trim().length > 0 - ) { + if (credential.type === "oauth") { + // mastracode's getApiKey triggers refreshToken() when expires <= now, + // and persists the refreshed credential back into auth storage. + const accessToken = await authStorage.getApiKey( + ANTHROPIC_AUTH_PROVIDER_ID, + ); + if (!accessToken || accessToken.trim().length === 0) return null; + authStorage.reload(); + const refreshed = authStorage.get(ANTHROPIC_AUTH_PROVIDER_ID); return { - apiKey: credential.access.trim(), + apiKey: accessToken.trim(), source: "auth-storage", kind: "oauth", expiresAt: - typeof credential.expires === "number" - ? credential.expires + refreshed?.type === "oauth" && typeof refreshed.expires === "number" + ? refreshed.expires : undefined, }; } @@ -209,24 +213,22 @@ export function getCredentialsFromAuthStorage(): ClaudeCredentials | null { return null; } -export function getCredentialsFromAnySource(): ClaudeCredentials | null { - const resolvers = [ - getCredentialsFromConfig, - getCredentialsFromKeychain, - getCredentialsFromAuthStorage, - ]; +export async function getCredentialsFromAnySource(): Promise { + const syncResolvers = [getCredentialsFromConfig, getCredentialsFromKeychain]; let firstExpired: ClaudeCredentials | null = null; - for (const resolve of resolvers) { + for (const resolve of syncResolvers) { const credential = resolve(); - if (!credential) { - continue; - } - if (!isClaudeCredentialExpired(credential)) { - return credential; - } + if (!credential) continue; + if (!isClaudeCredentialExpired(credential)) return credential; firstExpired ??= credential; } + const storageCredential = await getCredentialsFromAuthStorage(); + if (storageCredential && !isClaudeCredentialExpired(storageCredential)) { + return storageCredential; + } + firstExpired ??= storageCredential ?? null; + return firstExpired; } diff --git a/packages/chat/src/server/desktop/chat-service/auth-storage-utils.ts b/packages/chat/src/server/desktop/chat-service/auth-storage-utils.ts index bcdff698c59..065e8226c25 100644 --- a/packages/chat/src/server/desktop/chat-service/auth-storage-utils.ts +++ b/packages/chat/src/server/desktop/chat-service/auth-storage-utils.ts @@ -1,3 +1,8 @@ +// WORKAROUND: backup/restore API keys across OAuth connect/disconnect. +// mastracode's resolveModel only reads API keys from the main authStorage +// slot, which OAuth login overwrites and disconnect clears. We back up to +// the dedicated apikey: slot before OAuth and restore after disconnect. +// Remove once mastra-ai/mastra#15483 lands and we bump mastracode. import type { AuthMethod, AuthStorageLike, @@ -16,10 +21,14 @@ export function setApiKeyForProvider( } authStorage.reload(); + // Store in main slot (mastracode's resolveModel reads from here). authStorage.set(providerId, { type: "api_key", key: trimmedApiKey, }); + // Also store in dedicated apikey: slot as a backup that survives + // OAuth connect/disconnect cycles. + authStorage.setStoredApiKey(providerId, trimmedApiKey); } export function clearApiKeyForProvider( @@ -27,12 +36,54 @@ export function clearApiKeyForProvider( providerId: string, ): void { authStorage.reload(); + + // Clear the dedicated backup slot. + if (authStorage.hasStoredApiKey(providerId)) { + authStorage.remove(`apikey:${providerId}`); + } + + // Clear the main slot if it holds an api_key. const credential = authStorage.get(providerId); - if (credential?.type !== "api_key") { - return; + if (credential?.type === "api_key") { + authStorage.remove(providerId); } +} - authStorage.remove(providerId); +/** + * Save the current API key to the backup slot before OAuth overwrites + * the main slot. Call this BEFORE authStorage.login(). + */ +export function backupApiKeyBeforeOAuth( + authStorage: AuthStorageLike, + providerId: string, +): void { + authStorage.reload(); + const credential = authStorage.get(providerId); + if ( + credential?.type === "api_key" && + credential.key.trim().length > 0 && + !authStorage.hasStoredApiKey(providerId) + ) { + authStorage.setStoredApiKey(providerId, credential.key.trim()); + } +} + +/** + * Restore the API key from the backup slot after OAuth is disconnected. + * Call this AFTER removing the OAuth credential from the main slot. + */ +export function restoreApiKeyAfterOAuthDisconnect( + authStorage: AuthStorageLike, + providerId: string, +): void { + authStorage.reload(); + const storedApiKey = authStorage.getStoredApiKey(providerId); + if (storedApiKey && storedApiKey.trim().length > 0) { + authStorage.set(providerId, { + type: "api_key", + key: storedApiKey.trim(), + }); + } } export function clearCredentialForProvider( @@ -60,5 +111,9 @@ export function resolveAuthMethodForProvider( if (credential?.type === "api_key" && credential.key.trim().length > 0) { return "api_key"; } + // Check the backup slot — API key may have been displaced by OAuth. + if (authStorage.hasStoredApiKey(providerId)) { + return "api_key"; + } return null; } diff --git a/packages/chat/src/server/desktop/chat-service/chat-service.test.ts b/packages/chat/src/server/desktop/chat-service/chat-service.test.ts index 41f71244718..9f42658e9ff 100644 --- a/packages/chat/src/server/desktop/chat-service/chat-service.test.ts +++ b/packages/chat/src/server/desktop/chat-service/chat-service.test.ts @@ -26,6 +26,16 @@ type FakeAuthStorage = { (providerId: string, callbacks: OAuthCallbacks) => Promise > >; + setStoredApiKey: ReturnType< + typeof mock<(providerId: string, key: string) => void> + >; + hasStoredApiKey: ReturnType boolean>>; + getStoredApiKey: ReturnType< + typeof mock<(providerId: string) => string | undefined> + >; + getApiKey: ReturnType< + typeof mock<(providerId: string) => Promise> + >; clear: () => void; }; @@ -41,6 +51,27 @@ function createFakeAuthStorage(): FakeAuthStorage { credentials.delete(providerId); }), login: mock(async () => {}), + setStoredApiKey: mock((providerId: string, key: string) => { + credentials.set(`apikey:${providerId}`, { + type: "api_key", + key, + } as Credential); + }), + hasStoredApiKey: mock((providerId: string) => + credentials.has(`apikey:${providerId}`), + ), + getStoredApiKey: mock((providerId: string) => { + const cred = credentials.get(`apikey:${providerId}`); + return cred?.type === "api_key" ? cred.key : undefined; + }), + getApiKey: mock(async (providerId: string) => { + const cred = credentials.get(providerId); + if (cred?.type === "oauth" && "access" in cred) { + return (cred as Record).access as string; + } + const stored = credentials.get(`apikey:${providerId}`); + return stored?.type === "api_key" ? stored.key : undefined; + }), clear: () => { credentials.clear(); }, @@ -102,8 +133,8 @@ mock.module("mastracode", () => ({ mock.module("../auth/anthropic", () => ({ getCredentialsFromConfig: () => anthropicConfigCredential, getCredentialsFromKeychain: () => anthropicKeychainCredential, - getCredentialsFromAnySource: () => null, - getCredentialsFromAuthStorage: () => null, + getCredentialsFromAnySource: async () => null, + getCredentialsFromAuthStorage: async () => null, getAnthropicProviderOptions: () => ({}), isClaudeCredentialExpired: (credential: { kind: "apiKey" | "oauth"; @@ -127,6 +158,10 @@ describe("ChatService OpenAI auth storage", () => { fakeAuthStorage.set.mockClear(); fakeAuthStorage.remove.mockClear(); fakeAuthStorage.login.mockClear(); + fakeAuthStorage.setStoredApiKey.mockClear(); + fakeAuthStorage.hasStoredApiKey.mockClear(); + fakeAuthStorage.getStoredApiKey.mockClear(); + fakeAuthStorage.getApiKey.mockClear(); anthropicConfigCredential = null; anthropicKeychainCredential = null; testSupersetHomeDir = mkdtempSync(join(tmpdir(), "chat-service-test-")); @@ -175,11 +210,11 @@ describe("ChatService OpenAI auth storage", () => { await chatService.clearOpenAIApiKey(); expect(createAuthStorageMock).toHaveBeenCalledTimes(1); - expect(fakeAuthStorage.set).toHaveBeenCalledWith("openai-codex", { - type: "api_key", - key: "test-key", - }); - expect(fakeAuthStorage.remove).toHaveBeenCalledWith("openai-codex"); + expect(fakeAuthStorage.setStoredApiKey).toHaveBeenCalledWith( + "openai-codex", + "test-key", + ); + expect(fakeAuthStorage.remove).toHaveBeenCalledWith("apikey:openai-codex"); }); it("stores and clears Anthropic API key in standalone auth storage", async () => { @@ -193,11 +228,11 @@ describe("ChatService OpenAI auth storage", () => { await chatService.clearAnthropicApiKey(); expect(createAuthStorageMock).toHaveBeenCalledTimes(1); - expect(fakeAuthStorage.set).toHaveBeenCalledWith("anthropic", { - type: "api_key", - key: "test-anthropic-key", - }); - expect(fakeAuthStorage.remove).toHaveBeenCalledWith("anthropic"); + expect(fakeAuthStorage.setStoredApiKey).toHaveBeenCalledWith( + "anthropic", + "test-anthropic-key", + ); + expect(fakeAuthStorage.remove).toHaveBeenCalledWith("apikey:anthropic"); }); it("persists Anthropic OAuth credentials to auth storage on completion", async () => { @@ -243,31 +278,7 @@ describe("ChatService OpenAI auth storage", () => { }), ); expect(result.expiresAt).toBe(oauthExpiresAt); - expect(chatService.getAnthropicAuthStatus().method).toBe("oauth"); - }); - - it("switches Anthropic status from oauth to api key when api key is saved", async () => { - const chatService = new ChatService(); - - fakeAuthStorage.login.mockImplementation( - async (providerId: string, callbacks: OAuthCallbacks) => { - callbacks.onAuth({ url: "https://claude.ai/oauth/authorize?foo=bar" }); - const code = await callbacks.onPrompt({ message: "Paste code" }); - expect(code).toBe("auth-code#state"); - fakeAuthStorage.set(providerId, { - type: "oauth", - access: "oauth-access-token", - expires: Date.now() + 60 * 60 * 1000, - }); - }, - ); - - await chatService.startAnthropicOAuth(); - await chatService.completeAnthropicOAuth({ code: "auth-code#state" }); - expect(chatService.getAnthropicAuthStatus().method).toBe("oauth"); - - await chatService.setAnthropicApiKey({ apiKey: " api-key " }); - expect(chatService.getAnthropicAuthStatus().method).toBe("api_key"); + expect((await chatService.getAnthropicAuthStatus()).method).toBe("oauth"); }); it("prefers a managed Anthropic API key over env-config credentials", async () => { @@ -284,7 +295,7 @@ describe("ChatService OpenAI auth storage", () => { ); expect(process.env.ANTHROPIC_API_KEY).toBeUndefined(); expect(process.env.ANTHROPIC_AUTH_TOKEN).toBeUndefined(); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "api_key", source: "managed", @@ -292,12 +303,12 @@ describe("ChatService OpenAI auth storage", () => { }); }); - it("ignores Anthropic runtime env credentials without managed auth", () => { + it("ignores Anthropic runtime env credentials without managed auth", async () => { const chatService = new ChatService(); process.env.ANTHROPIC_AUTH_TOKEN = "external-oauth-token"; - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: false, method: null, source: null, @@ -305,7 +316,7 @@ describe("ChatService OpenAI auth storage", () => { }); }); - it("prefers external Anthropic credentials over managed auth", () => { + it("prefers external Anthropic credentials over managed auth", async () => { const chatService = new ChatService(); anthropicConfigCredential = { @@ -313,12 +324,9 @@ describe("ChatService OpenAI auth storage", () => { source: "config", kind: "oauth", }; - fakeAuthStorage.set("anthropic", { - type: "api_key", - key: "managed-api-key", - }); + fakeAuthStorage.setStoredApiKey("anthropic", "managed-api-key"); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "oauth", source: "external", @@ -326,7 +334,7 @@ describe("ChatService OpenAI auth storage", () => { }); }); - it("surfaces hidden managed Anthropic OAuth when external Claude auth wins", () => { + it("surfaces hidden managed Anthropic OAuth when external Claude auth wins", async () => { const chatService = new ChatService(); anthropicConfigCredential = { @@ -340,7 +348,7 @@ describe("ChatService OpenAI auth storage", () => { expires: Date.now() + 60 * 60 * 1000, }); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "oauth", source: "external", @@ -349,16 +357,13 @@ describe("ChatService OpenAI auth storage", () => { }); }); - it("prefers managed Anthropic auth over runtime env credentials", () => { + it("prefers managed Anthropic auth over runtime env credentials", async () => { const chatService = new ChatService(); process.env.ANTHROPIC_AUTH_TOKEN = "external-oauth-token"; - fakeAuthStorage.set("anthropic", { - type: "api_key", - key: "managed-api-key", - }); + fakeAuthStorage.setStoredApiKey("anthropic", "managed-api-key"); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "api_key", source: "managed", @@ -366,7 +371,7 @@ describe("ChatService OpenAI auth storage", () => { }); }); - it("marks expired external Anthropic OAuth as expired", () => { + it("marks expired external Anthropic OAuth as expired", async () => { const chatService = new ChatService(); anthropicConfigCredential = { @@ -376,7 +381,7 @@ describe("ChatService OpenAI auth storage", () => { expiresAt: Date.now() - 1_000, }; - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: false, method: "oauth", source: "external", @@ -384,7 +389,7 @@ describe("ChatService OpenAI auth storage", () => { }); }); - it("falls back to managed Anthropic auth when external OAuth is expired", () => { + it("falls back to managed Anthropic auth when external OAuth is expired", async () => { const chatService = new ChatService(); anthropicConfigCredential = { @@ -393,12 +398,9 @@ describe("ChatService OpenAI auth storage", () => { kind: "oauth", expiresAt: Date.now() - 1_000, }; - fakeAuthStorage.set("anthropic", { - type: "api_key", - key: "managed-api-key", - }); + fakeAuthStorage.setStoredApiKey("anthropic", "managed-api-key"); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "api_key", source: "managed", @@ -415,7 +417,7 @@ describe("ChatService OpenAI auth storage", () => { expires: Date.now() + 60 * 60 * 1000, }); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "oauth", source: "managed", @@ -426,7 +428,7 @@ describe("ChatService OpenAI auth storage", () => { await chatService.disconnectAnthropicOAuth(); expect(fakeAuthStorage.remove).toHaveBeenCalledWith("anthropic"); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: false, method: null, source: null, @@ -447,10 +449,10 @@ describe("ChatService OpenAI auth storage", () => { ); expect(process.env.ANTHROPIC_AUTH_TOKEN).toBeUndefined(); expect(process.env.ANTHROPIC_API_KEY).toBeUndefined(); - expect(fakeAuthStorage.set).toHaveBeenCalledWith("anthropic", { - type: "api_key", - key: "gateway-token", - }); + expect(fakeAuthStorage.setStoredApiKey).toHaveBeenCalledWith( + "anthropic", + "gateway-token", + ); expect(chatService.getAnthropicEnvConfig()).toEqual({ envText: "ANTHROPIC_BASE_URL=https://ai-gateway.vercel.sh\nANTHROPIC_AUTH_TOKEN=gateway-token", @@ -459,7 +461,7 @@ describe("ChatService OpenAI auth storage", () => { ANTHROPIC_AUTH_TOKEN: "gateway-token", }, }); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "api_key", source: "managed", @@ -474,7 +476,7 @@ describe("ChatService OpenAI auth storage", () => { access: "oauth-access-token", expires: Date.now() + 60 * 60 * 1000, }); - expect(chatService.getAnthropicAuthStatus().method).toBe("oauth"); + expect((await chatService.getAnthropicAuthStatus()).method).toBe("oauth"); await chatService.setAnthropicEnvConfig({ envText: @@ -482,7 +484,7 @@ describe("ChatService OpenAI auth storage", () => { }); expect(fakeAuthStorage.remove).toHaveBeenCalledWith("anthropic"); - expect(chatService.getAnthropicAuthStatus().method).toBe("api_key"); + expect((await chatService.getAnthropicAuthStatus()).method).toBe("api_key"); }); it("persists Anthropic env config without API key/token", async () => { @@ -498,7 +500,7 @@ describe("ChatService OpenAI auth storage", () => { ANTHROPIC_BASE_URL: "https://ai-gateway.vercel.sh", }, }); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: false, method: null, source: null, @@ -522,11 +524,11 @@ describe("ChatService OpenAI auth storage", () => { await chatService.disconnectAnthropicOAuth(); expect(fakeAuthStorage.remove).toHaveBeenCalledWith("anthropic"); - expect(fakeAuthStorage.set).toHaveBeenLastCalledWith("anthropic", { - type: "api_key", - key: "gateway-token", - }); - expect(chatService.getAnthropicAuthStatus()).toEqual({ + expect(fakeAuthStorage.setStoredApiKey).toHaveBeenCalledWith( + "anthropic", + "gateway-token", + ); + expect(await chatService.getAnthropicAuthStatus()).toEqual({ authenticated: true, method: "api_key", source: "managed", @@ -545,10 +547,10 @@ describe("ChatService OpenAI auth storage", () => { expect(process.env.ANTHROPIC_AUTH_TOKEN).toBeUndefined(); expect(process.env.CLAUDE_CODE_USE_BEDROCK).toBe("1"); expect(process.env.AWS_REGION).toBe("us-east-1"); - expect(fakeAuthStorage.set).toHaveBeenCalledWith("anthropic", { - type: "api_key", - key: "env-key", - }); + expect(fakeAuthStorage.setStoredApiKey).toHaveBeenCalledWith( + "anthropic", + "env-key", + ); expect(chatService.getAnthropicEnvConfig()).toEqual({ envText: "ANTHROPIC_API_KEY=env-key\nCLAUDE_CODE_USE_BEDROCK=1\nAWS_REGION=us-east-1", @@ -572,12 +574,12 @@ describe("ChatService OpenAI auth storage", () => { expect(process.env.ANTHROPIC_BASE_URL).toBeUndefined(); expect(process.env.ANTHROPIC_AUTH_TOKEN).toBeUndefined(); expect(process.env.ANTHROPIC_API_KEY).toBeUndefined(); - expect(fakeAuthStorage.remove).toHaveBeenCalledWith("anthropic"); + expect(fakeAuthStorage.remove).toHaveBeenCalledWith("apikey:anthropic"); expect(chatService.getAnthropicEnvConfig()).toEqual({ envText: "", variables: {}, }); - expect(chatService.getAnthropicAuthStatus().method).toBeNull(); + expect((await chatService.getAnthropicAuthStatus()).method).toBeNull(); }); it("deletes previously applied pass-through env keys when settings change", async () => { diff --git a/packages/chat/src/server/desktop/chat-service/chat-service.ts b/packages/chat/src/server/desktop/chat-service/chat-service.ts index c6a2cacc808..449071bcef8 100644 --- a/packages/chat/src/server/desktop/chat-service/chat-service.ts +++ b/packages/chat/src/server/desktop/chat-service/chat-service.ts @@ -27,9 +27,11 @@ import { } from "./anthropic-env-config"; import type { AuthStatus } from "./auth-storage-types"; import { + backupApiKeyBeforeOAuth, clearApiKeyForProvider, clearCredentialForProvider, resolveAuthMethodForProvider, + restoreApiKeyAfterOAuthDisconnect, setApiKeyForProvider, } from "./auth-storage-utils"; import { @@ -103,11 +105,32 @@ export class ChatService { ); } - getAnthropicAuthStatus(): AuthStatus { + async getAnthropicAuthStatus(): Promise { const authStorage = this.getAuthStorage(); authStorage.reload(); - const storedCredential = authStorage.get(ANTHROPIC_AUTH_PROVIDER_ID); + let storedCredential = authStorage.get(ANTHROPIC_AUTH_PROVIDER_ID); const hasManagedOAuth = storedCredential?.type === "oauth"; + + // If managed OAuth is past its expiry, give mastracode a chance to + // refresh it before downgrading status to "expired". Mastracode's + // getApiKey uses the stored refresh token via the anthropic provider. + if ( + storedCredential?.type === "oauth" && + typeof storedCredential.expires === "number" && + storedCredential.expires <= Date.now() + ) { + try { + await authStorage.getApiKey(ANTHROPIC_AUTH_PROVIDER_ID); + authStorage.reload(); + storedCredential = authStorage.get(ANTHROPIC_AUTH_PROVIDER_ID); + } catch (error) { + // Refresh failed; fall through to expired-state handling below. + console.warn( + "[chat-service] Anthropic OAuth refresh failed, falling back to expired state:", + error, + ); + } + } const configCredential = getAnthropicCredentialsFromConfig(); const keychainCredential = getAnthropicCredentialsFromKeychain(); const externalCandidates = [configCredential, keychainCredential].filter( @@ -594,6 +617,7 @@ export class ChatService { } clearCredentialForProvider(authStorage, providerId); + restoreApiKeyAfterOAuthDisconnect(authStorage, providerId); removedProviderIds.push(providerId); } this.logAuthResolution("openai", { @@ -607,6 +631,9 @@ export class ChatService { async completeOpenAIOAuth(input: { code?: string; }): Promise<{ success: true }> { + for (const providerId of OPENAI_AUTH_PROVIDER_IDS) { + backupApiKeyBeforeOAuth(this.getAuthStorage(), providerId); + } await this.oauthFlowController.complete( this.getOpenAIOAuthFlowOptions(), input.code, @@ -696,6 +723,11 @@ export class ChatService { const credential = authStorage.get(ANTHROPIC_AUTH_PROVIDER_ID); if (credential?.type === "oauth") { clearCredentialForProvider(authStorage, ANTHROPIC_AUTH_PROVIDER_ID); + // Restore API key from backup slot if one was saved before OAuth connect. + restoreApiKeyAfterOAuthDisconnect( + authStorage, + ANTHROPIC_AUTH_PROVIDER_ID, + ); const config = getAnthropicEnvConfigFromDisk({ configPath: this.anthropicEnvConfigPath, }); @@ -715,6 +747,8 @@ export class ChatService { async completeAnthropicOAuth(input: { code?: string; }): Promise<{ success: true; expiresAt: number }> { + // Save API key to backup slot before OAuth overwrites the main slot. + backupApiKeyBeforeOAuth(this.getAuthStorage(), ANTHROPIC_AUTH_PROVIDER_ID); const credential = await this.oauthFlowController.complete( this.getAnthropicOAuthFlowOptions(), input.code, @@ -779,10 +813,7 @@ export class ChatService { const authStorage = this.getAuthStorage(); authStorage.reload(); - authStorage.set(ANTHROPIC_AUTH_PROVIDER_ID, { - type: "api_key", - key: apiKey, - }); + authStorage.setStoredApiKey(ANTHROPIC_AUTH_PROVIDER_ID, apiKey); } private applyAnthropicRuntimeEnv(variables: AnthropicEnvVariables): void { diff --git a/packages/chat/src/server/desktop/index.ts b/packages/chat/src/server/desktop/index.ts index 39ee5ab0fb3..06df62dddd6 100644 --- a/packages/chat/src/server/desktop/index.ts +++ b/packages/chat/src/server/desktop/index.ts @@ -16,13 +16,4 @@ export { export { ChatService } from "./chat-service"; export type { ChatServiceRouter } from "./router"; export { createChatServiceRouter } from "./router"; -export type { - SmallModelCredential, - SmallModelProvider, - SmallModelProviderId, -} from "./small-model"; -export { getDefaultSmallModelProviders } from "./small-model"; -export { - generateTitleFromMessage, - generateTitleFromMessageWithStreamingModel, -} from "./title-generation"; +export { generateTitleFromMessage } from "./title-generation"; diff --git a/packages/chat/src/server/desktop/small-model/index.ts b/packages/chat/src/server/desktop/small-model/index.ts deleted file mode 100644 index 70228ebbe29..00000000000 --- a/packages/chat/src/server/desktop/small-model/index.ts +++ /dev/null @@ -1,6 +0,0 @@ -export type { - SmallModelCredential, - SmallModelProvider, - SmallModelProviderId, -} from "./small-model"; -export { getDefaultSmallModelProviders } from "./small-model"; diff --git a/packages/chat/src/server/desktop/small-model/small-model.test.ts b/packages/chat/src/server/desktop/small-model/small-model.test.ts deleted file mode 100644 index 182d1468dea..00000000000 --- a/packages/chat/src/server/desktop/small-model/small-model.test.ts +++ /dev/null @@ -1,391 +0,0 @@ -import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"; - -type MockOpenAICredentials = { - apiKey: string; - kind: "apiKey" | "oauth"; - source: string; - expiresAt?: number; - accountId?: string; - providerId?: string; -}; - -const createAnthropicModelMock = mock(() => "anthropic-default-model"); -let lastCreateOpenAIOptions: { fetch?: typeof fetch } | undefined; -const createOpenAIMock = mock((options?: { fetch?: typeof fetch }) => { - lastCreateOpenAIOptions = options; - return Object.assign(createOpenAIResponsesModelMock, { - chat: createOpenAIChatModelMock, - responses: createOpenAIResponsesModelMock, - }); -}); -const createOpenAIResponsesModelMock = mock( - () => "openai-default-responses-model", -); -const createOpenAIChatModelMock = mock(() => "openai-default-chat-model"); -const getAnthropicCredentialsFromAnySourceMock = mock(() => null); -const getAnthropicProviderOptionsMock = mock(() => ({ apiKey: "unused" })); -const getOpenAICredentialsFromAnySourceMock = mock( - (() => null) as () => MockOpenAICredentials | null, -); -const getOpenAICredentialsFromAuthStorageMock = mock( - (authStorage?: { - reload: () => void; - get: (providerId: string) => - | { - type: "api_key"; - key: string; - } - | { - type: "oauth"; - access: string; - expires?: number; - accountId?: string; - } - | undefined; - }): MockOpenAICredentials | null => { - const storage = authStorage ?? fakeAuthStorage; - storage.reload(); - - const credentials = ["openai-codex", "openai"] - .map((providerId) => { - const credential = storage.get(providerId); - if (!credential) { - return null; - } - - if (credential.type === "api_key" && credential.key.trim()) { - return { - apiKey: credential.key.trim(), - kind: "apiKey" as const, - source: "auth-storage", - providerId, - }; - } - - if (credential.type === "oauth" && credential.access.trim()) { - return { - apiKey: credential.access.trim(), - kind: "oauth" as const, - source: "auth-storage", - expiresAt: credential.expires, - accountId: credential.accountId?.trim() || undefined, - providerId, - }; - } - - return null; - }) - .filter( - (credential): credential is MockOpenAICredentials => - credential !== null, - ); - - return ( - credentials.find( - (credential) => - credential.kind !== "oauth" || - typeof credential.expiresAt !== "number" || - Date.now() < credential.expiresAt, - ) ?? - credentials[0] ?? - null - ); - }, -); -const fakeAuthStorage = { - reload: mock(() => {}), - get: mock(() => undefined), - getApiKey: mock(async () => null), -}; -const originalFetch = globalThis.fetch; -const fetchMock = mock(async () => new Response(null, { status: 200 })); - -mock.module("@ai-sdk/anthropic", () => ({ - createAnthropic: mock(() => createAnthropicModelMock), -})); - -mock.module("@ai-sdk/openai", () => ({ - createOpenAI: createOpenAIMock, -})); - -mock.module("mastracode", () => ({ - createAuthStorage: mock(() => fakeAuthStorage), - createMastraCode: mock(async () => ({ - harness: {}, - mcpManager: null, - hookManager: null, - authStorage: null, - storageWarning: undefined, - })), -})); - -mock.module("../auth/anthropic", () => ({ - getCredentialsFromAnySource: getAnthropicCredentialsFromAnySourceMock, - getCredentialsFromAuthStorage: () => null, - getCredentialsFromConfig: () => null, - getCredentialsFromKeychain: () => null, - getAnthropicProviderOptions: getAnthropicProviderOptionsMock, - isClaudeCredentialExpired: () => false, - createAnthropicOAuthSession: () => {}, - exchangeAnthropicAuthorizationCode: () => {}, -})); - -mock.module("../auth/openai", () => ({ - getOpenAICredentialsFromAnySource: getOpenAICredentialsFromAnySourceMock, - getOpenAICredentialsFromAuthStorage: getOpenAICredentialsFromAuthStorageMock, - isOpenAICredentialExpired: (credential: { - kind: "apiKey" | "oauth"; - expiresAt?: number; - }) => - credential.kind === "oauth" && - typeof credential.expiresAt === "number" && - Date.now() >= credential.expiresAt, -})); - -const { getDefaultSmallModelProviders } = await import("./small-model"); - -describe("getDefaultSmallModelProviders", () => { - beforeEach(() => { - getAnthropicCredentialsFromAnySourceMock.mockReturnValue(null); - getOpenAICredentialsFromAnySourceMock.mockReturnValue(null); - getAnthropicProviderOptionsMock.mockClear(); - createAnthropicModelMock.mockClear(); - createOpenAIMock.mockClear(); - getOpenAICredentialsFromAuthStorageMock.mockClear(); - lastCreateOpenAIOptions = undefined; - createOpenAIResponsesModelMock.mockClear(); - createOpenAIChatModelMock.mockClear(); - fakeAuthStorage.reload.mockClear(); - fakeAuthStorage.get.mockClear(); - fakeAuthStorage.getApiKey.mockClear(); - fakeAuthStorage.get.mockReturnValue(undefined); - fakeAuthStorage.getApiKey.mockResolvedValue(null); - fetchMock.mockClear(); - globalThis.fetch = fetchMock as typeof fetch; - }); - - afterEach(() => { - globalThis.fetch = originalFetch; - }); - - it("uses the OpenAI Codex OAuth model path for OAuth credentials", async () => { - getOpenAICredentialsFromAnySourceMock.mockReturnValue({ - apiKey: "openai-key", - kind: "oauth", - source: "auth-storage", - accountId: "chatgpt-account", - providerId: "openai-codex", - }); - fakeAuthStorage.get.mockReturnValue({ - type: "oauth", - access: "oauth-access-token", - accountId: "chatgpt-account", - }); - - const openAIProvider = getDefaultSmallModelProviders().find( - (provider) => provider.id === "openai", - ); - - expect(openAIProvider).toBeDefined(); - const credentials = openAIProvider?.resolveCredentials(); - expect(credentials).toEqual({ - apiKey: "openai-key", - kind: "oauth", - source: "auth-storage", - accountId: "chatgpt-account", - providerId: "openai-codex", - }); - if (!openAIProvider || !credentials) { - throw new Error("OpenAI provider should resolve OAuth credentials"); - } - - const model = await openAIProvider.createModel(credentials); - - expect(model).toBe("openai-default-responses-model"); - expect(createOpenAIResponsesModelMock).toHaveBeenCalledWith( - "gpt-5.1-codex-mini", - ); - expect(createOpenAIChatModelMock).not.toHaveBeenCalled(); - }); - - it("uses the resolved OpenAI provider id for the OAuth transport", async () => { - getOpenAICredentialsFromAnySourceMock.mockReturnValue({ - apiKey: "legacy-openai-key", - kind: "oauth", - source: "auth-storage", - providerId: "openai", - }); - fakeAuthStorage.get.mockImplementation((providerId: string) => { - if (providerId !== "openai") { - return undefined; - } - - return { - type: "oauth", - access: "legacy-openai-access", - }; - }); - - const openAIProvider = getDefaultSmallModelProviders().find( - (provider) => provider.id === "openai", - ); - if (!openAIProvider) { - throw new Error("OpenAI provider should exist"); - } - - const credentials = openAIProvider.resolveCredentials(); - if (!credentials) { - throw new Error("OpenAI provider should resolve OAuth credentials"); - } - - await openAIProvider.createModel(credentials); - - const oauthFetch = lastCreateOpenAIOptions?.fetch; - if (!oauthFetch) { - throw new Error("OpenAI OAuth provider should pass a fetch override"); - } - await oauthFetch("https://api.openai.com/v1/responses", { - headers: { - Authorization: "Bearer should-be-replaced", - }, - }); - - expect(fakeAuthStorage.get).toHaveBeenCalledWith("openai"); - expect(fakeAuthStorage.get).not.toHaveBeenCalledWith("openai-codex"); - }); - - it("preserves Request details when rewriting the OpenAI OAuth transport", async () => { - getOpenAICredentialsFromAnySourceMock.mockReturnValue({ - apiKey: "openai-key", - kind: "oauth", - source: "auth-storage", - accountId: "chatgpt-account", - providerId: "openai-codex", - }); - fakeAuthStorage.get.mockReturnValue({ - type: "oauth", - access: "oauth-access-token", - accountId: "chatgpt-account", - }); - - const openAIProvider = getDefaultSmallModelProviders().find( - (provider) => provider.id === "openai", - ); - if (!openAIProvider) { - throw new Error("OpenAI provider should exist"); - } - - const credentials = openAIProvider.resolveCredentials(); - if (!credentials) { - throw new Error("OpenAI provider should resolve OAuth credentials"); - } - - await openAIProvider.createModel(credentials); - - const oauthFetch = lastCreateOpenAIOptions?.fetch; - if (!oauthFetch) { - throw new Error("OpenAI OAuth provider should pass a fetch override"); - } - - const abortController = new AbortController(); - const request = new Request("https://api.openai.com/v1/responses", { - method: "POST", - body: JSON.stringify({ prompt: "name this workspace" }), - headers: { - "Content-Type": "application/json", - "X-Test-Header": "present", - Authorization: "Bearer should-be-replaced", - }, - signal: abortController.signal, - }); - - await oauthFetch(request); - - const [forwardedRequest] = fetchMock.mock.calls.at(-1) ?? []; - expect(forwardedRequest).toBeInstanceOf(Request); - if (!(forwardedRequest instanceof Request)) { - throw new Error("fetch should receive a rewritten Request"); - } - - expect(forwardedRequest.url).toBe( - "https://chatgpt.com/backend-api/codex/responses", - ); - expect(forwardedRequest.method).toBe("POST"); - expect(await forwardedRequest.clone().text()).toBe( - JSON.stringify({ prompt: "name this workspace" }), - ); - expect(forwardedRequest.headers.get("content-type")).toBe( - "application/json", - ); - expect(forwardedRequest.headers.get("x-test-header")).toBe("present"); - expect(forwardedRequest.headers.get("authorization")).toBe( - "Bearer oauth-access-token", - ); - expect(forwardedRequest.headers.get("chatgpt-account-id")).toBe( - "chatgpt-account", - ); - expect(forwardedRequest.signal).toBe(abortController.signal); - }); - - it("uses the OpenAI chat model path for API key credentials", async () => { - getOpenAICredentialsFromAnySourceMock.mockReturnValue({ - apiKey: "openai-key", - kind: "apiKey", - source: "auth-storage", - providerId: "openai-codex", - }); - - const openAIProvider = getDefaultSmallModelProviders().find( - (provider) => provider.id === "openai", - ); - - expect(openAIProvider).toBeDefined(); - const credentials = openAIProvider?.resolveCredentials(); - expect(credentials).toEqual({ - apiKey: "openai-key", - kind: "apiKey", - source: "auth-storage", - providerId: "openai-codex", - }); - if (!openAIProvider || !credentials) { - throw new Error("OpenAI provider should resolve API key credentials"); - } - - const model = await openAIProvider.createModel(credentials); - - expect(model).toBe("openai-default-chat-model"); - expect(createOpenAIChatModelMock).toHaveBeenCalledWith("gpt-4o-mini"); - expect(createOpenAIResponsesModelMock).not.toHaveBeenCalled(); - }); - - it("uses the Anthropic provider path for supported credentials", async () => { - getAnthropicCredentialsFromAnySourceMock.mockReturnValue({ - apiKey: "anthropic-key", - kind: "apiKey", - source: "config", - }); - - const anthropicProvider = getDefaultSmallModelProviders().find( - (provider) => provider.id === "anthropic", - ); - - expect(anthropicProvider).toBeDefined(); - const credentials = anthropicProvider?.resolveCredentials(); - expect(credentials).toEqual({ - apiKey: "anthropic-key", - kind: "apiKey", - source: "config", - }); - if (!anthropicProvider || !credentials) { - throw new Error("Anthropic provider should resolve credentials"); - } - - const model = await anthropicProvider.createModel(credentials); - - expect(model).toBe("anthropic-default-model"); - expect(getAnthropicProviderOptionsMock).toHaveBeenCalledWith(credentials); - expect(createAnthropicModelMock).toHaveBeenCalledWith( - "claude-haiku-4-5-20251001", - ); - }); -}); diff --git a/packages/chat/src/server/desktop/small-model/small-model.ts b/packages/chat/src/server/desktop/small-model/small-model.ts deleted file mode 100644 index 43434cbef51..00000000000 --- a/packages/chat/src/server/desktop/small-model/small-model.ts +++ /dev/null @@ -1,146 +0,0 @@ -import { createAnthropic } from "@ai-sdk/anthropic"; -import { createOpenAI } from "@ai-sdk/openai"; -import { createAuthStorage } from "mastracode"; -import { - type ClaudeCredentials, - getCredentialsFromAnySource as getAnthropicCredentialsFromAnySource, - getAnthropicProviderOptions, -} from "../auth/anthropic"; -import { - getOpenAICredentialsFromAnySource, - type OpenAICredentials, -} from "../auth/openai"; -import { OPENAI_AUTH_PROVIDER_ID } from "../auth/provider-ids"; - -export type SmallModelProviderId = "anthropic" | "openai"; - -export interface SmallModelCredential { - apiKey: string; - kind: "apiKey" | "oauth"; - source: string; - expiresAt?: number; - accountId?: string; - providerId?: string; -} - -export interface SmallModelProvider { - id: SmallModelProviderId; - name: string; - resolveCredentials: () => SmallModelCredential | null; - isSupported: (credentials: SmallModelCredential) => { - supported: boolean; - reason?: string; - }; - createModel: ( - credentials: SmallModelCredential, - ) => unknown | Promise; -} - -const OPENAI_CODEX_API_ENDPOINT = - "https://chatgpt.com/backend-api/codex/responses"; -const OPENAI_CODEX_SMALL_MODEL_ID = "gpt-5.1-codex-mini"; -const OPENAI_API_SMALL_MODEL_ID = "gpt-4o-mini"; - -function createOpenAICodexOAuthModel(credentials: OpenAICredentials) { - const authStorage = createAuthStorage(); - const openAIAuthProviderId = - credentials.providerId ?? OPENAI_AUTH_PROVIDER_ID; - const oauthFetchImpl = async ( - url: Parameters[0], - init?: Parameters[1], - ): Promise => { - authStorage.reload(); - const storedCredential = authStorage.get(openAIAuthProviderId); - if (!storedCredential || storedCredential.type !== "oauth") { - throw new Error("Not logged in to OpenAI Codex. Reconnect OpenAI."); - } - - let accessToken = storedCredential.access; - if ( - typeof storedCredential.expires === "number" && - Date.now() >= storedCredential.expires - ) { - const refreshedToken = await authStorage.getApiKey(openAIAuthProviderId); - if (!refreshedToken) { - throw new Error( - "Failed to refresh OpenAI Codex token. Please reconnect OpenAI.", - ); - } - accessToken = refreshedToken; - authStorage.reload(); - } - - const refreshedCredential = authStorage.get(openAIAuthProviderId); - const accountId = - refreshedCredential && - typeof refreshedCredential === "object" && - "accountId" in refreshedCredential && - typeof refreshedCredential.accountId === "string" && - refreshedCredential.accountId.trim().length > 0 - ? refreshedCredential.accountId.trim() - : credentials.accountId?.trim() || undefined; - - const baseRequest = new Request(url, init); - const parsedUrl = new URL(baseRequest.url); - const shouldRewrite = - parsedUrl.pathname.includes("/v1/responses") || - parsedUrl.pathname.includes("/chat/completions"); - const outgoingRequest = new Request( - shouldRewrite ? OPENAI_CODEX_API_ENDPOINT : baseRequest.url, - baseRequest, - ); - const headers = new Headers(outgoingRequest.headers); - headers.delete("authorization"); - headers.set("Authorization", `Bearer ${accessToken}`); - if (accountId) { - headers.set("ChatGPT-Account-Id", accountId); - } - - return fetch( - new Request(outgoingRequest, { - headers, - }), - ); - }; - const bunFetch = globalThis.fetch as typeof fetch & { - preconnect?: typeof globalThis.fetch; - }; - const oauthFetch = Object.assign( - oauthFetchImpl, - typeof bunFetch.preconnect === "function" - ? { preconnect: bunFetch.preconnect.bind(globalThis.fetch) } - : {}, - ) as typeof fetch; - - return createOpenAI({ - apiKey: "oauth-dummy-key", - fetch: oauthFetch, - }).responses(OPENAI_CODEX_SMALL_MODEL_ID); -} - -export function getDefaultSmallModelProviders(): SmallModelProvider[] { - return [ - { - id: "anthropic", - name: "Anthropic", - resolveCredentials: () => getAnthropicCredentialsFromAnySource(), - isSupported: () => ({ supported: true }), - createModel: (credentials) => - createAnthropic( - getAnthropicProviderOptions(credentials as ClaudeCredentials), - )("claude-haiku-4-5-20251001"), - }, - { - id: "openai", - name: "OpenAI", - resolveCredentials: () => getOpenAICredentialsFromAnySource(), - isSupported: () => ({ supported: true }), - createModel: (credentials) => - credentials.kind === "oauth" - ? createOpenAICodexOAuthModel(credentials as OpenAICredentials) - : createOpenAI({ apiKey: credentials.apiKey }).chat( - OPENAI_API_SMALL_MODEL_ID, - ), - }, - ]; -} diff --git a/packages/chat/src/server/desktop/title-generation/index.ts b/packages/chat/src/server/desktop/title-generation/index.ts index 9a47fa5ef81..5dcc0b3bf42 100644 --- a/packages/chat/src/server/desktop/title-generation/index.ts +++ b/packages/chat/src/server/desktop/title-generation/index.ts @@ -1,4 +1 @@ -export { - generateTitleFromMessage, - generateTitleFromMessageWithStreamingModel, -} from "./title-generation"; +export { generateTitleFromMessage } from "./title-generation"; diff --git a/packages/chat/src/server/desktop/title-generation/title-generation.test.ts b/packages/chat/src/server/desktop/title-generation/title-generation.test.ts deleted file mode 100644 index f58d11b8206..00000000000 --- a/packages/chat/src/server/desktop/title-generation/title-generation.test.ts +++ /dev/null @@ -1,37 +0,0 @@ -import { describe, expect, it, mock } from "bun:test"; - -const streamTextMock = mock(() => ({ - text: Promise.resolve(" Checking In "), -})); - -mock.module("ai", () => ({ - streamText: streamTextMock, -})); - -const { generateTitleFromMessageWithStreamingModel } = await import( - "./title-generation" -); - -describe("generateTitleFromMessageWithStreamingModel", () => { - it("streams a title with Codex-compatible provider options", async () => { - const title = await generateTitleFromMessageWithStreamingModel({ - message: " hey boss how are you ", - model: { id: "test-model" } as never, - instructions: "You generate concise workspace titles.", - }); - - expect(title).toBe("Checking In"); - expect(streamTextMock).toHaveBeenCalledWith({ - model: { id: "test-model" }, - system: "You generate concise workspace titles.", - prompt: - "Return only a short title for this user message:\nhey boss how are you", - providerOptions: { - openai: { - instructions: "You generate concise workspace titles.", - store: false, - }, - }, - }); - }); -}); diff --git a/packages/chat/src/server/desktop/title-generation/title-generation.ts b/packages/chat/src/server/desktop/title-generation/title-generation.ts index 25073636d2c..94f3ab23dfe 100644 --- a/packages/chat/src/server/desktop/title-generation/title-generation.ts +++ b/packages/chat/src/server/desktop/title-generation/title-generation.ts @@ -1,5 +1,3 @@ -import { type LanguageModel, streamText } from "ai"; - type TitleModel = unknown; type TitleAgent = { generateTitleFromUserMessage: (args: { @@ -71,29 +69,3 @@ export async function generateTitleFromMessage( return title?.trim() || null; } - -export async function generateTitleFromMessageWithStreamingModel(params: { - message: string; - model: LanguageModel; - instructions?: string; -}): Promise { - const cleanedMessage = params.message.trim(); - if (!cleanedMessage) { - return null; - } - - const instructions = params.instructions ?? "You generate concise titles."; - const result = streamText({ - model: params.model, - system: instructions, - prompt: `Return only a short title for this user message:\n${cleanedMessage}`, - providerOptions: { - openai: { - instructions, - store: false, - }, - }, - }); - - return (await result.text).trim() || null; -} diff --git a/packages/chat/src/server/shared/index.ts b/packages/chat/src/server/shared/index.ts new file mode 100644 index 00000000000..c54241172f5 --- /dev/null +++ b/packages/chat/src/server/shared/index.ts @@ -0,0 +1,6 @@ +export { + getSmallModel, + getSmallModelCandidates, + type SmallModelCandidate, + type SmallModelProviderId, +} from "./small-model"; diff --git a/packages/chat/src/server/shared/small-model/get-small-model.ts b/packages/chat/src/server/shared/small-model/get-small-model.ts new file mode 100644 index 00000000000..8810b06fe95 --- /dev/null +++ b/packages/chat/src/server/shared/small-model/get-small-model.ts @@ -0,0 +1,349 @@ +import { existsSync, readFileSync } from "node:fs"; +import { homedir } from "node:os"; +import { join } from "node:path"; +import { createAnthropic } from "@ai-sdk/anthropic"; +import { createOpenAI } from "@ai-sdk/openai"; +import { createAuthStorage } from "mastracode"; +import { + type ClaudeCredentials, + getCredentialsFromConfig as getAnthropicCredentialsFromConfig, + getCredentialsFromKeychain as getAnthropicCredentialsFromKeychain, + getAnthropicProviderOptions, + isClaudeCredentialExpired, +} from "../../desktop/auth/anthropic"; +import { + getOpenAICredentialsFromAnySource, + isOpenAICredentialExpired, + type OpenAICredentials, +} from "../../desktop/auth/openai"; +import { OPENAI_AUTH_PROVIDER_ID } from "../../desktop/auth/provider-ids"; +import { parseAnthropicEnvText } from "../../desktop/chat-service/anthropic-env-config"; + +const ANTHROPIC_SMALL_MODEL_ID = "claude-haiku-4-5-20251001"; +const OPENAI_API_SMALL_MODEL_ID = "gpt-4o-mini"; +const OPENAI_CODEX_SMALL_MODEL_ID = "gpt-5.1-codex-mini"; +const OPENAI_CODEX_API_ENDPOINT = + "https://chatgpt.com/backend-api/codex/responses"; + +export type SmallModelProviderId = "anthropic" | "openai"; + +export interface SmallModelCandidate { + providerId: SmallModelProviderId; + providerName: string; + credentialKind: "apiKey" | "oauth"; + credentialSource: string; + createModel: () => unknown; +} + +/** + * FORK NOTE: ported from upstream #3517's `getSmallModel()` but rebuilt + * on top of fork's credential resolvers so it still honors: + * - Anthropic OAuth (claude-code-20250219 / oauth-2025-04-20 headers via + * getAnthropicProviderOptions — upstream lost this when it switched to + * apiKey-only resolution) + * - Anthropic managed env config (~/.superset/chat-anthropic-env.json + * with ANTHROPIC_API_KEY or ANTHROPIC_AUTH_TOKEN; AUTH_TOKEN is + * routed through the OAuth header path, not apiKey) + * - OpenAI Codex OAuth (custom fetch that rewrites to the Codex + * backend endpoint and refreshes access tokens via mastracode) + * - OpenAI API key in mastracode AuthStorage's `openai-codex` slot + * + * Upstream's version collapsed credentials to apiKey-only. We keep the + * simpler `getSmallModel()` export for upstream-compatible callers + * (runtime.ts title generation) and add `getSmallModelCandidates()` so + * the fork callSmallModel shim can iterate providers in order and + * record attempts properly (restoring provider fallback behavior). + */ +function buildCandidates(): SmallModelCandidate[] { + const candidates: SmallModelCandidate[] = []; + + const envApiKey = process.env.ANTHROPIC_API_KEY?.trim(); + if (envApiKey) { + candidates.push({ + providerId: "anthropic", + providerName: "Anthropic", + credentialKind: "apiKey", + credentialSource: "env:ANTHROPIC_API_KEY", + createModel: () => + createAnthropic({ apiKey: envApiKey })(ANTHROPIC_SMALL_MODEL_ID), + }); + } + + const anthropicStored = resolveAnthropicCredentialsSync(); + if (anthropicStored) { + candidates.push({ + providerId: "anthropic", + providerName: "Anthropic", + credentialKind: anthropicStored.kind === "oauth" ? "oauth" : "apiKey", + credentialSource: anthropicStored.source, + createModel: () => + createAnthropic(getAnthropicProviderOptions(anthropicStored))( + ANTHROPIC_SMALL_MODEL_ID, + ), + }); + } + + const anthropicEnvConfigCred = resolveAnthropicEnvConfigCredential(); + if (anthropicEnvConfigCred) { + candidates.push({ + providerId: "anthropic", + providerName: "Anthropic", + credentialKind: + anthropicEnvConfigCred.kind === "oauth" ? "oauth" : "apiKey", + credentialSource: anthropicEnvConfigCred.source, + createModel: () => + createAnthropic(getAnthropicProviderOptions(anthropicEnvConfigCred))( + ANTHROPIC_SMALL_MODEL_ID, + ), + }); + } + + const envOpenAIKey = process.env.OPENAI_API_KEY?.trim(); + if (envOpenAIKey) { + candidates.push({ + providerId: "openai", + providerName: "OpenAI", + credentialKind: "apiKey", + credentialSource: "env:OPENAI_API_KEY", + createModel: () => + createOpenAI({ apiKey: envOpenAIKey }).chat(OPENAI_API_SMALL_MODEL_ID), + }); + } + + const openaiCreds = getOpenAICredentialsFromAnySource(); + if (openaiCreds && !isOpenAICredentialExpired(openaiCreds)) { + candidates.push({ + providerId: "openai", + providerName: "OpenAI", + credentialKind: openaiCreds.kind === "oauth" ? "oauth" : "apiKey", + credentialSource: openaiCreds.source, + createModel: () => + openaiCreds.kind === "oauth" + ? createOpenAICodexOAuthModel(openaiCreds) + : createOpenAI({ apiKey: openaiCreds.apiKey }).chat( + OPENAI_API_SMALL_MODEL_ID, + ), + }); + } + + return candidates; +} + +export function getSmallModelCandidates(): SmallModelCandidate[] { + return buildCandidates(); +} + +/** + * Returns the first viable small-model AI-SDK LanguageModel or null. + * Upstream-compatible surface for simple single-model callers + * (runtime.ts title generation, ai-name.ts workspace naming). + * + * Iterates every candidate and returns the first one whose + * `createModel()` does not throw, so a broken-but-listed credential + * (e.g. stale cached account id) doesn't block the next provider. + * Runtime-level failures (expired OAuth 401, rate limits) still need + * to be handled by the caller — those surface when the returned + * model is actually invoked, not when it's constructed. + */ +export function getSmallModel(): unknown | null { + for (const candidate of buildCandidates()) { + try { + return candidate.createModel(); + } catch { + // Try the next candidate. + } + } + return null; +} + +// ---- Anthropic credential resolution helpers ------------------------------- + +/** + * Synchronous Anthropic credential resolver. Fork's + * `getCredentialsFromAnySource` is async because it may kick a + * mastracode token refresh. For the small-model candidate list we need + * a sync decision, so we stick to synchronous sources (config file, + * keychain, auth-storage main slot). If the resulting OAuth token is + * actually expired, createAnthropic will 401 and the shim falls + * through to the next candidate. + */ +function resolveAnthropicCredentialsSync(): ClaudeCredentials | null { + // Walk the sync sources in priority order and return the first + // non-expired credential. Unlike getCredentialsFromAnySource() we do + // NOT fall back to a known-expired credential at the end — expired + // OAuth tokens would poison buildCandidates() and block the later + // env-config / OpenAI candidates, which matter for getSmallModel()'s + // direct callers where we can't retry after a 401. + const sources: Array<() => ClaudeCredentials | null> = [ + () => { + try { + return getAnthropicCredentialsFromConfig(); + } catch { + return null; + } + }, + () => { + try { + return getAnthropicCredentialsFromKeychain(); + } catch { + return null; + } + }, + () => resolveAnthropicFromStoreSync(), + ]; + for (const resolve of sources) { + const credential = resolve(); + if (!credential) continue; + if (!isClaudeCredentialExpired(credential)) return credential; + } + return null; +} + +function resolveAnthropicFromStoreSync(): ClaudeCredentials | null { + try { + const storage = createAuthStorage(); + storage.reload(); + const raw = storage.get("anthropic"); + if (!raw || typeof raw !== "object") return null; + const value = raw as Record; + if ( + value.type === "api_key" && + typeof value.key === "string" && + value.key.trim().length > 0 + ) { + return { + apiKey: value.key.trim(), + source: "auth-storage", + kind: "apiKey", + }; + } + if ( + value.type === "oauth" && + typeof value.access === "string" && + value.access.trim().length > 0 + ) { + return { + apiKey: value.access.trim(), + source: "auth-storage", + kind: "oauth", + expiresAt: + typeof value.expires === "number" ? value.expires : undefined, + }; + } + } catch { + // Fall through to null. + } + return null; +} + +function resolveAnthropicEnvConfigCredential(): ClaudeCredentials | null { + try { + const supersetHome = + process.env.SUPERSET_HOME_DIR?.trim() || join(homedir(), ".superset"); + const path = join(supersetHome, "chat-anthropic-env.json"); + if (!existsSync(path)) return null; + const parsed = JSON.parse(readFileSync(path, "utf-8")) as { + envText?: string; + }; + if (typeof parsed.envText !== "string") return null; + const variables = parseAnthropicEnvText(parsed.envText); + const apiKey = variables.ANTHROPIC_API_KEY?.trim(); + if (apiKey) { + // `source: "config"` keeps us inside fork's ClaudeCredentials + // union; the actual display label comes from + // SmallModelCandidate.credentialSource below. + return { apiKey, source: "config", kind: "apiKey" }; + } + const authToken = variables.ANTHROPIC_AUTH_TOKEN?.trim(); + if (authToken) { + // FORK NOTE: AUTH_TOKEN must flow through the OAuth path + // (authToken + anthropic-beta / x-app headers) — routing it + // through `apiKey` was the original PR #313 regression. + return { apiKey: authToken, source: "config", kind: "oauth" }; + } + } catch { + // Swallow — missing / malformed config falls back to other sources. + } + return null; +} + +// ---- OpenAI Codex OAuth model ---------------------------------------------- + +function createOpenAICodexOAuthModel(credentials: OpenAICredentials) { + const authStorage = createAuthStorage(); + const openAIAuthProviderId = + credentials.providerId ?? OPENAI_AUTH_PROVIDER_ID; + const oauthFetchImpl = async ( + url: Parameters[0], + init?: Parameters[1], + ): Promise => { + authStorage.reload(); + const storedCredential = authStorage.get(openAIAuthProviderId); + if (!storedCredential || storedCredential.type !== "oauth") { + throw new Error("Not logged in to OpenAI Codex. Reconnect OpenAI."); + } + + let accessToken = storedCredential.access; + if ( + typeof storedCredential.expires === "number" && + Date.now() >= storedCredential.expires + ) { + const refreshedToken = await authStorage.getApiKey(openAIAuthProviderId); + if (!refreshedToken) { + throw new Error( + "Failed to refresh OpenAI Codex token. Please reconnect OpenAI.", + ); + } + accessToken = refreshedToken; + authStorage.reload(); + } + + const refreshedCredential = authStorage.get(openAIAuthProviderId); + const accountId = + refreshedCredential && + typeof refreshedCredential === "object" && + "accountId" in refreshedCredential && + typeof refreshedCredential.accountId === "string" && + refreshedCredential.accountId.trim().length > 0 + ? refreshedCredential.accountId.trim() + : credentials.accountId?.trim() || undefined; + + // biome-ignore-start lint/suspicious/noExplicitAny: fetch signature varies across runtimes (bun vs. node vs. electron) and the cross-package typecheck context loses the DOM Request type overloads. + const baseRequest = new Request(url as any, init as any); + // biome-ignore-end lint/suspicious/noExplicitAny: matching pair + const parsedUrl = new URL(baseRequest.url); + const shouldRewrite = + parsedUrl.pathname.includes("/v1/responses") || + parsedUrl.pathname.includes("/chat/completions"); + const outgoingRequest = new Request( + shouldRewrite ? OPENAI_CODEX_API_ENDPOINT : baseRequest.url, + baseRequest, + ); + const headers = new Headers(outgoingRequest.headers); + headers.delete("authorization"); + headers.set("Authorization", `Bearer ${accessToken}`); + if (accountId) { + headers.set("ChatGPT-Account-Id", accountId); + } + + return fetch( + new Request(outgoingRequest, { + headers, + }), + ); + }; + const bunFetch = globalThis.fetch as typeof fetch & { + preconnect?: typeof globalThis.fetch; + }; + const oauthFetch = Object.assign( + oauthFetchImpl, + typeof bunFetch.preconnect === "function" + ? { preconnect: bunFetch.preconnect.bind(globalThis.fetch) } + : {}, + ) as typeof fetch; + + return createOpenAI({ + apiKey: "oauth-dummy-key", + fetch: oauthFetch, + }).responses(OPENAI_CODEX_SMALL_MODEL_ID); +} diff --git a/packages/chat/src/server/shared/small-model/index.ts b/packages/chat/src/server/shared/small-model/index.ts new file mode 100644 index 00000000000..d2b53f46e8e --- /dev/null +++ b/packages/chat/src/server/shared/small-model/index.ts @@ -0,0 +1,6 @@ +export { + getSmallModel, + getSmallModelCandidates, + type SmallModelCandidate, + type SmallModelProviderId, +} from "./get-small-model"; diff --git a/packages/chat/src/server/trpc/utils/runtime/runtime.test.ts b/packages/chat/src/server/trpc/utils/runtime/runtime.test.ts index 5a15c0534bf..fdac90bd49b 100644 --- a/packages/chat/src/server/trpc/utils/runtime/runtime.test.ts +++ b/packages/chat/src/server/trpc/utils/runtime/runtime.test.ts @@ -1,33 +1,27 @@ import { describe, expect, it, mock } from "bun:test"; import type { RuntimeSession } from "./runtime"; -let generateTitleFromMessageWithStreamingModelResult = ""; - -const generateTitleFromMessageWithStreamingModelMock = mock( - (async (_params: { message: string; model: unknown }) => - generateTitleFromMessageWithStreamingModelResult) as ( +let generateTitleFromMessageResult = ""; + +const generateTitleFromMessageMock = mock( + (async (_params: { + message: string; + agentModel?: unknown; + agent?: unknown; + modelId?: string; + }) => generateTitleFromMessageResult) as ( args: unknown, - ) => Promise, + ) => Promise, ); -const getDefaultSmallModelProvidersMock = mock(() => [ - { - id: "mock", - name: "Mock", - resolveCredentials: () => ({ - apiKey: "test", - kind: "apiKey", - source: "test", - }), - isSupported: () => ({ supported: true }), - createModel: () => ({}), - }, -]); +const getSmallModelMock = mock(() => ({}) as unknown); mock.module("../../../desktop", () => ({ - generateTitleFromMessageWithStreamingModel: - generateTitleFromMessageWithStreamingModelMock, - getDefaultSmallModelProviders: getDefaultSmallModelProvidersMock, + generateTitleFromMessage: generateTitleFromMessageMock, +})); + +mock.module("../../../shared/small-model", () => ({ + getSmallModel: getSmallModelMock, })); const { @@ -96,7 +90,7 @@ function createRuntimeForTitleTest(options?: { const generatedTitle = options?.generatedTitle ?? ""; // Set the mock return value for this test - generateTitleFromMessageWithStreamingModelResult = generatedTitle; + generateTitleFromMessageResult = generatedTitle; const runtime: RuntimeSession = { sessionId: "11111111-1111-1111-1111-111111111111", diff --git a/packages/chat/src/server/trpc/utils/runtime/runtime.ts b/packages/chat/src/server/trpc/utils/runtime/runtime.ts index b32b1aea23e..e4a012ed135 100644 --- a/packages/chat/src/server/trpc/utils/runtime/runtime.ts +++ b/packages/chat/src/server/trpc/utils/runtime/runtime.ts @@ -1,10 +1,8 @@ import type { AppRouter } from "@superset/trpc"; import type { createTRPCClient } from "@trpc/client"; import type { createMastraCode } from "mastracode"; -import { - generateTitleFromMessageWithStreamingModel, - getDefaultSmallModelProviders, -} from "../../../desktop"; +import { generateTitleFromMessage } from "../../../desktop"; +import { getSmallModel } from "../../../shared/small-model"; import type { ThinkingLevel } from "../../zod"; const SUBAGENT_AGENT_TYPES = ["explore", "plan", "execute"] as const; @@ -512,31 +510,21 @@ export async function generateAndSetTitle( // Use a small model for title generation instead of the chat model, // because the chat model may use OAuth auth that isn't accessible via // process.env API keys (e.g. OpenAI Codex OAuth). - const providers = getDefaultSmallModelProviders(); - for (const provider of providers) { - const creds = provider.resolveCredentials(); - if (!creds) continue; - const { supported } = provider.isSupported(creds); - if (!supported) continue; - try { - const model = await provider.createModel(creds); - const title = await generateTitleFromMessageWithStreamingModel({ - message: text, - model: model as import("ai").LanguageModel, - }); - if (!title?.trim()) return; - - await apiClient.chat.updateTitle.mutate({ - sessionId: runtime.sessionId, - title: title.trim(), - }); - return; - } catch (error) { - console.warn( - `[chat] Title generation failed with ${provider.id}, trying next provider:`, - error, - ); - } + const model = getSmallModel(); + if (!model) return; + try { + const title = await generateTitleFromMessage({ + message: text, + agentModel: model, + }); + if (!title?.trim()) return; + + await apiClient.chat.updateTitle.mutate({ + sessionId: runtime.sessionId, + title: title.trim(), + }); + } catch (error) { + console.warn("[chat] Title generation failed:", error); } } catch (error) { console.warn("[chat] Title generation failed:", error); diff --git a/packages/host-service/package.json b/packages/host-service/package.json index f3c646681e4..79d5371e02a 100644 --- a/packages/host-service/package.json +++ b/packages/host-service/package.json @@ -45,6 +45,7 @@ "@hono/node-ws": "^1.3.0", "@hono/trpc-server": "^0.3.4", "@octokit/rest": "^22.0.1", + "@superset/chat": "workspace:*", "@superset/shared": "workspace:*", "@superset/trpc": "workspace:*", "@superset/workspace-fs": "workspace:*", @@ -54,7 +55,7 @@ "better-sqlite3": "12.6.2", "drizzle-orm": "0.45.1", "hono": "^4.8.5", - "mastracode": "0.9.2", + "mastracode": "0.14.0", "node-pty": "1.1.0", "simple-git": "^3.30.0", "superjson": "^2.2.5", diff --git a/packages/host-service/src/providers/model-providers/LocalModelProvider/LocalModelProvider.ts b/packages/host-service/src/providers/model-providers/LocalModelProvider/LocalModelProvider.ts index a620b6b4277..e57bceef5ee 100644 --- a/packages/host-service/src/providers/model-providers/LocalModelProvider/LocalModelProvider.ts +++ b/packages/host-service/src/providers/model-providers/LocalModelProvider/LocalModelProvider.ts @@ -30,12 +30,12 @@ export class LocalModelProvider implements ModelProviderRuntimeResolver { this.anthropicEnvConfigPath = options?.anthropicEnvConfigPath; } - private resolveRuntimeEnv(): { + private async resolveRuntimeEnv(): Promise<{ env: Record; cleanupKeys: string[]; hasUsableRuntimeEnv: boolean; - } { - const anthropicCredential = resolveAnthropicCredential(); + }> { + const anthropicCredential = await resolveAnthropicCredential(); const openaiCredential = resolveOpenAICredential(); const anthropicEnvConfig = getAnthropicEnvConfig({ configPath: this.anthropicEnvConfigPath, @@ -54,11 +54,11 @@ export class LocalModelProvider implements ModelProviderRuntimeResolver { } async hasUsableRuntimeEnv(): Promise { - return this.resolveRuntimeEnv().hasUsableRuntimeEnv; + return (await this.resolveRuntimeEnv()).hasUsableRuntimeEnv; } async prepareRuntimeEnv(): Promise { - const runtimeEnv = this.resolveRuntimeEnv(); + const runtimeEnv = await this.resolveRuntimeEnv(); this.currentRuntimeEnv = applyRuntimeEnv( runtimeEnv.env, runtimeEnv.cleanupKeys, diff --git a/packages/host-service/src/providers/model-providers/LocalModelProvider/utils/resolveAnthropicCredential.ts b/packages/host-service/src/providers/model-providers/LocalModelProvider/utils/resolveAnthropicCredential.ts index 96cdb2c68c6..354ce598f82 100644 --- a/packages/host-service/src/providers/model-providers/LocalModelProvider/utils/resolveAnthropicCredential.ts +++ b/packages/host-service/src/providers/model-providers/LocalModelProvider/utils/resolveAnthropicCredential.ts @@ -82,7 +82,7 @@ function getAnthropicCredentialFromKeychain(): LocalResolvedCredential | null { return null; } -function getAnthropicCredentialFromAuthStorage(): LocalResolvedCredential | null { +async function getAnthropicCredentialFromAuthStorage(): Promise { try { const authStorage = createAuthStorage(); authStorage.reload(); @@ -97,18 +97,45 @@ function getAnthropicCredentialFromAuthStorage(): LocalResolvedCredential | null return { kind: "api_key" }; } - if ( - credential.type === "oauth" && - typeof credential.access === "string" && - credential.access.trim().length > 0 - ) { - return { - kind: "oauth", - expiresAt: - typeof credential.expires === "number" - ? credential.expires - : undefined, - }; + if (credential.type === "oauth") { + const expiresAt = + typeof credential.expires === "number" ? credential.expires : undefined; + if (typeof expiresAt === "number" && Date.now() >= expiresAt) { + try { + await authStorage.getApiKey(ANTHROPIC_PROVIDER_ID); + authStorage.reload(); + const refreshed = authStorage.get(ANTHROPIC_PROVIDER_ID); + if ( + isObjectRecord(refreshed) && + refreshed.type === "oauth" && + typeof refreshed.access === "string" && + refreshed.access.trim().length > 0 + ) { + return { + kind: "oauth", + expiresAt: + typeof refreshed.expires === "number" + ? refreshed.expires + : undefined, + }; + } + // Refresh returned no usable access token — callers must + // fall back rather than proxying an expired credential. + return null; + } catch (error) { + console.warn( + "[LocalModelProvider] Anthropic OAuth refresh failed:", + error, + ); + return null; + } + } + if ( + typeof credential.access === "string" && + credential.access.trim().length > 0 + ) { + return { kind: "oauth", expiresAt }; + } } } catch { // Ignore auth storage read failures for now. @@ -117,10 +144,10 @@ function getAnthropicCredentialFromAuthStorage(): LocalResolvedCredential | null return null; } -export function resolveAnthropicCredential(): LocalResolvedCredential | null { +export async function resolveAnthropicCredential(): Promise { return ( getAnthropicCredentialFromConfig() ?? getAnthropicCredentialFromKeychain() ?? - getAnthropicCredentialFromAuthStorage() + (await getAnthropicCredentialFromAuthStorage()) ); } diff --git a/packages/host-service/src/trpc/router/workspace-creation/utils/ai-branch-name.ts b/packages/host-service/src/trpc/router/workspace-creation/utils/ai-branch-name.ts new file mode 100644 index 00000000000..b2bbec615af --- /dev/null +++ b/packages/host-service/src/trpc/router/workspace-creation/utils/ai-branch-name.ts @@ -0,0 +1,55 @@ +import { generateTitleFromMessage } from "@superset/chat/server/desktop"; +import { getSmallModel } from "@superset/chat/server/shared"; +import { deduplicateBranchName } from "./sanitize-branch"; + +const BRANCH_NAME_INSTRUCTIONS = + "Generate a concise git branch name (2-4 words, kebab-case, descriptive). Return ONLY the branch name, nothing else."; + +const MAX_BRANCH_LENGTH = 100; + +/** + * Light sanitizer for AI-generated branch names — lowercase, kebab-case, + * restricted character set. Differs from desktop's full sanitizer: no + * multi-segment support (AI generates a single segment) and no preserve-case + * options. + */ +function sanitizeGeneratedBranchName(raw: string): string { + return raw + .toLowerCase() + .trim() + .replace(/\s+/g, "-") + .replace(/[^a-z0-9._+@-]/g, "") + .replace(/\.{2,}/g, ".") + .replace(/-+/g, "-") + .replace(/\.lock$/g, "") + .slice(0, MAX_BRANCH_LENGTH) + .replace(/^[-.]+|[-.]+$/g, ""); +} + +export async function generateBranchNameFromPrompt( + prompt: string, + existingBranches: string[], +): Promise { + const model = getSmallModel(); + if (!model) return null; + + let generated: string | null; + try { + generated = await generateTitleFromMessage({ + message: prompt, + agentModel: model, + agentId: "branch-namer", + agentName: "Branch Namer", + instructions: BRANCH_NAME_INSTRUCTIONS, + tracingContext: { surface: "host-service-branch-name" }, + }); + } catch (error) { + console.warn("[generateBranchNameFromPrompt] generation failed:", error); + return null; + } + + if (!generated) return null; + const sanitized = sanitizeGeneratedBranchName(generated); + if (!sanitized) return null; + return deduplicateBranchName(sanitized, existingBranches); +} diff --git a/packages/host-service/src/trpc/router/workspace-creation/workspace-creation.ts b/packages/host-service/src/trpc/router/workspace-creation/workspace-creation.ts index 480b7126a15..1f00c99608f 100644 --- a/packages/host-service/src/trpc/router/workspace-creation/workspace-creation.ts +++ b/packages/host-service/src/trpc/router/workspace-creation/workspace-creation.ts @@ -16,6 +16,7 @@ import { createSimpleGitWithEnv } from "../../../runtime/git/simple-git"; import { createTerminalSessionInternal } from "../../../terminal/terminal"; import type { HostServiceContext } from "../../../types"; import { protectedProcedure, router } from "../../index"; +import { generateBranchNameFromPrompt } from "./utils/ai-branch-name"; import { execGh } from "./utils/exec-gh"; import { derivePrLocalBranchName } from "./utils/pr-branch-name"; import { resolveStartPoint } from "./utils/resolve-start-point"; @@ -631,6 +632,28 @@ export const workspaceCreationRouter = router({ return { defaultBranch, items, nextCursor }; }), + generateBranchName: protectedProcedure + .input(z.object({ projectId: z.string(), prompt: z.string() })) + .mutation(async ({ ctx, input }) => { + const trimmed = input.prompt.trim(); + if (!trimmed) return { branchName: null }; + + const localProject = ctx.db.query.projects + .findFirst({ where: eq(projects.id, input.projectId) }) + .sync(); + if (!localProject) return { branchName: null }; + + const existingBranches = await listBranchNames( + ctx, + localProject.repoPath, + ); + const branchName = await generateBranchNameFromPrompt( + trimmed, + existingBranches, + ); + return { branchName }; + }), + /** * Create a new workspace. Always creates — never opens an existing one. * Branch name is sanitized and deduplicated server-side.