diff --git a/packages/opencode/src/provider/models.ts b/packages/opencode/src/provider/models.ts index bae33178467..99fef9749c3 100644 --- a/packages/opencode/src/provider/models.ts +++ b/packages/opencode/src/provider/models.ts @@ -11,6 +11,19 @@ import { Filesystem } from "../util/filesystem" // Falls back to undefined in dev mode when snapshot doesn't exist /* @ts-ignore */ +// Cache format types for prompt caching +export const CacheFormat = z.enum(["anthropic", "openrouter", "bedrock", "openaiCompatible"]) +export type CacheFormat = z.infer<typeof CacheFormat> + +export const Caching = z.union([ + z.boolean(), + z.object({ + format: CacheFormat.optional(), + positions: z.array(z.enum(["system", "first", "last"])).optional(), + }), +]) +export type Caching = z.infer<typeof Caching> + export namespace ModelsDev { const log = Log.create({ service: "models.dev" }) const filepath = path.join(Global.Path.cache, "models.json") @@ -67,6 +80,7 @@ export namespace ModelsDev { headers: z.record(z.string(), z.string()).optional(), provider: z.object({ npm: z.string().optional(), api: z.string().optional() }).optional(), variants: z.record(z.string(), z.record(z.string(), z.any())).optional(), + caching: Caching.optional(), }) export type Model = z.infer<typeof Model> diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 349073197d7..d363cfc4220 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -9,7 +9,7 @@ import { BunProc } from "../bun" import { Hash } from "../util/hash" import { Plugin } from "../plugin" import { NamedError } from "@opencode-ai/util/error" -import { ModelsDev } from "./models" +import { ModelsDev, Caching } from "./models" import { Auth } from "../auth" import { Env } from "../env" import { Instance } from "../project/instance" @@ -735,6 +735,7 @@ export namespace Provider { headers: z.record(z.string(), z.string()), release_date: z.string(), variants: z.record(z.string(), 
z.record(z.string(), z.any())).optional(), + caching: Caching.optional(), }) .meta({ ref: "Model", @@ -816,6 +817,7 @@ export namespace Provider { }, release_date: model.release_date, variants: {}, + caching: model.caching, } m.variants = mapValues(ProviderTransform.variants(m), (v) => v) diff --git a/packages/opencode/src/provider/transform.ts b/packages/opencode/src/provider/transform.ts index 05b9f031fe6..85eb56fb96b 100644 --- a/packages/opencode/src/provider/transform.ts +++ b/packages/opencode/src/provider/transform.ts @@ -172,9 +172,50 @@ export namespace ProviderTransform { } function applyCaching(msgs: ModelMessage[], model: Provider.Model): ModelMessage[] { - const system = msgs.filter((msg) => msg.role === "system").slice(0, 2) - const final = msgs.filter((msg) => msg.role !== "system").slice(-2) + // Determine cache format from model.caching config or infer from provider + const npm = model.api.npm + const providerID = model.providerID + + // Get format from explicit config or infer from provider + let format: "anthropic" | "openrouter" | "bedrock" | "openaiCompatible" | undefined + if (model.caching && typeof model.caching === "object" && model.caching.format) { + format = model.caching.format + } else if (npm === "@ai-sdk/amazon-bedrock" || providerID.includes("bedrock")) { + format = "bedrock" + } else if (npm === "@ai-sdk/anthropic" || providerID === "anthropic") { + format = "anthropic" + } else if (npm === "@openrouter/ai-sdk-provider" || providerID === "openrouter") { + format = "openrouter" + } else { + // Default to openaiCompatible for other providers (kiro-gateway, etc.) 
+ format = "openaiCompatible" + } + + // Determine positions to cache + let positions: ("system" | "first" | "last")[] = ["system", "last"] + if (model.caching && typeof model.caching === "object" && model.caching.positions) { + positions = model.caching.positions + } + + // Select messages to cache based on positions + const messagesToCache: ModelMessage[] = [] + const systemMsgs = msgs.filter((msg) => msg.role === "system") + const nonSystemMsgs = msgs.filter((msg) => msg.role !== "system") + + if (positions.includes("system")) { + messagesToCache.push(...systemMsgs.slice(0, 2)) + } + if (positions.includes("first") && nonSystemMsgs.length > 0) { + messagesToCache.push(nonSystemMsgs[0]) + } + if (positions.includes("last") && nonSystemMsgs.length > 0) { + const lastMsg = nonSystemMsgs[nonSystemMsgs.length - 1] + if (!messagesToCache.includes(lastMsg)) { + messagesToCache.push(lastMsg) + } + } + // Build provider options for all formats (SDK will pick the right one) const providerOptions = { anthropic: { cacheControl: { type: "ephemeral" }, @@ -188,13 +229,13 @@ export namespace ProviderTransform { openaiCompatible: { cache_control: { type: "ephemeral" }, }, - copilot: { - copilot_cache_control: { type: "ephemeral" }, - }, } - for (const msg of unique([...system, ...final])) { - const useMessageLevelOptions = model.providerID === "anthropic" || model.providerID.includes("bedrock") + // Determine if we should use message-level or content-level options + // Anthropic and Bedrock use message-level, others use content-level + const useMessageLevelOptions = format === "anthropic" || format === "bedrock" + + for (const msg of unique(messagesToCache)) { const shouldUseContentOptions = !useMessageLevelOptions && Array.isArray(msg.content) && msg.content.length > 0 if (shouldUseContentOptions) { @@ -252,15 +293,10 @@ export namespace ProviderTransform { export function message(msgs: ModelMessage[], model: Provider.Model, options: Record<string, any>) { msgs = unsupportedParts(msgs, 
model) msgs = normalizeMessages(msgs, model, options) - if ( - (model.providerID === "anthropic" || - model.api.id.includes("anthropic") || - model.api.id.includes("claude") || - model.id.includes("anthropic") || - model.id.includes("claude") || - model.api.npm === "@ai-sdk/anthropic") && - model.api.npm !== "@ai-sdk/gateway" - ) { + + // Apply caching only when explicitly enabled via model.caching + // No auto-detection - user must opt-in via model config + if (model.caching === true || (model.caching && typeof model.caching === "object")) { msgs = applyCaching(msgs, model) } diff --git a/packages/opencode/test/provider/transform.test.ts b/packages/opencode/test/provider/transform.test.ts index 917d357eafa..0188eeb1032 100644 --- a/packages/opencode/test/provider/transform.test.ts +++ b/packages/opencode/test/provider/transform.test.ts @@ -1528,38 +1528,120 @@ describe("ProviderTransform.message - providerOptions key remapping", () => { }) }) -describe("ProviderTransform.message - claude w/bedrock custom inference profile", () => { - test("adds cachePoint", () => { - const model = { - id: "amazon-bedrock/custom-claude-sonnet-4.5", - providerID: "amazon-bedrock", +describe("ProviderTransform.message - bedrock prompt caching", () => { + const createBedrockModel = (apiId: string, providerID = "amazon-bedrock") => + ({ + id: `${providerID}/${apiId}`, + providerID, api: { - id: "arn:aws:bedrock:xxx:yyy:application-inference-profile/zzz", - url: "https://api.test.com", + id: apiId, + url: "https://bedrock.amazonaws.com", npm: "@ai-sdk/amazon-bedrock", }, - name: "Custom inference profile", + name: apiId, capabilities: {}, options: {}, headers: {}, - } as any + }) as any - const msgs = [ - { - role: "user", - content: "Hello", - }, - ] as any[] + test("Claude models on Bedrock get prompt caching", () => { + const model = createBedrockModel("anthropic.claude-3-5-sonnet-20241022-v2:0") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = 
ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) + test("Amazon Nova models get prompt caching", () => { + const model = createBedrockModel("amazon.nova-pro-v1:0") + const msgs = [{ role: "user", content: "Hello" }] as any[] const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) - expect(result[0].providerOptions?.bedrock).toEqual( - expect.objectContaining({ - cachePoint: { - type: "default", - }, - }), - ) + test("Nova models with nova- prefix get prompt caching", () => { + const model = createBedrockModel("nova-lite-v1:0") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) + + test("Llama models on Bedrock do NOT get prompt caching", () => { + const model = createBedrockModel("meta.llama3-70b-instruct-v1:0") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined() + }) + + test("Mistral models on Bedrock do NOT get prompt caching", () => { + const model = createBedrockModel("mistral.mistral-large-2402-v1:0") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined() + }) + + test("Cohere models on Bedrock do NOT get prompt caching", () => { + const model = createBedrockModel("cohere.command-r-plus-v1:0") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined() + }) + + test("Custom ARN with Claude in name gets prompt caching", 
() => { + const model = createBedrockModel("arn:aws:bedrock:us-east-1:123456789:custom-model/my-claude-finetune") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) + + test("Custom ARN without Claude in name does NOT get prompt caching", () => { + const model = createBedrockModel("arn:aws:bedrock:us-east-1:123456789:custom-model/my-llama-model") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined() + }) + + test("Cross-region inference profiles with Claude get prompt caching", () => { + const model = createBedrockModel("us.anthropic.claude-3-5-sonnet-20241022-v2:0") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) + + test("Application inference profile gets prompt caching when Claude-based", () => { + const model = createBedrockModel("arn:aws:bedrock:us-east-1:123456789:application-inference-profile/my-claude-profile") + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) + + test("Application inference profile with options.caching=true gets prompt caching", () => { + const model = { + ...createBedrockModel("arn:aws:bedrock:eu-west-1:995555607786:application-inference-profile/bzg00wo23901"), + options: { caching: true }, + } + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) + + test("Custom ARN 
with options.caching=true gets prompt caching", () => { + const model = { + ...createBedrockModel("arn:aws:bedrock:us-east-1:123456789:custom-model/my-custom-model"), + options: { caching: true }, + } + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toEqual({ type: "default" }) + }) + + test("Claude model with options.caching=false does NOT get prompt caching", () => { + const model = { + ...createBedrockModel("anthropic.claude-3-5-sonnet-20241022-v2:0"), + options: { caching: false }, + } + const msgs = [{ role: "user", content: "Hello" }] as any[] + const result = ProviderTransform.message(msgs, model, {}) + expect(result[0].providerOptions?.bedrock?.cachePoint).toBeUndefined() }) })