diff --git a/packages/opencode/src/tool/codesearch.ts b/packages/opencode/src/tool/codesearch.ts index 369cdb45048..0775addd24e 100644 --- a/packages/opencode/src/tool/codesearch.ts +++ b/packages/opencode/src/tool/codesearch.ts @@ -1,14 +1,26 @@ import z from "zod" import { Tool } from "./tool" import DESCRIPTION from "./codesearch.txt" +import { LRUCache } from "../util/cache" +import { retry, isRetryableError } from "@opencode-ai/util/retry" const API_CONFIG = { BASE_URL: "https://mcp.exa.ai", ENDPOINTS: { CONTEXT: "/mcp", }, + DEFAULT_TOKENS: 5000, + TIMEOUT_MS: 30000, } as const +// Cache configuration: 2 hour TTL (code docs change less frequently), 500 entries max +const codeCache = new LRUCache({ + namespace: "codesearch", + maxSize: 500, + ttl: 2 * 60 * 60 * 1000, // 2 hours + persist: true, +}) + interface McpCodeRequest { jsonrpc: string id: number @@ -32,6 +44,16 @@ interface McpCodeResponse { } } +/** + * Generate a cache key from search parameters + */ +function getCacheKey(params: { query: string; tokensNum: number }): string { + return JSON.stringify({ + q: params.query.toLowerCase().trim(), + t: params.tokensNum, + }) +} + export const CodeSearchTool = Tool.define("codesearch", { description: DESCRIPTION, parameters: z.object({ @@ -60,6 +82,19 @@ export const CodeSearchTool = Tool.define("codesearch", { }, }) + const tokensNum = params.tokensNum || API_CONFIG.DEFAULT_TOKENS + + // Check cache first + const cacheKey = getCacheKey({ query: params.query, tokensNum }) + const cached = await codeCache.get(cacheKey) + if (cached) { + return { + output: cached, + title: `Code search: ${params.query} (cached)`, + metadata: { cached: true }, + } + } + const codeRequest: McpCodeRequest = { jsonrpc: "2.0", id: 1, @@ -68,48 +103,71 @@ export const CodeSearchTool = Tool.define("codesearch", { name: "get_code_context_exa", arguments: { query: params.query, - tokensNum: params.tokensNum || 5000, + tokensNum, }, }, } - const controller = new AbortController() - const timeoutId = setTimeout(() => controller.abort(), 30000) - try { - const headers: Record = { - accept: "application/json, text/event-stream", - "content-type": "application/json", - } + const result = await retry( + async () => { + // Create fresh timeout for each retry attempt + const controller = new AbortController() + const timeoutId = setTimeout(() => controller.abort(), API_CONFIG.TIMEOUT_MS) - const response = await fetch(`${API_CONFIG.BASE_URL}${API_CONFIG.ENDPOINTS.CONTEXT}`, { - method: "POST", - headers, - body: JSON.stringify(codeRequest), - signal: AbortSignal.any([controller.signal, ctx.abort]), - }) + try { + const headers: Record = { + accept: "application/json, text/event-stream", + "content-type": "application/json", + } - clearTimeout(timeoutId) + const response = await fetch(`${API_CONFIG.BASE_URL}${API_CONFIG.ENDPOINTS.CONTEXT}`, { + method: "POST", + headers, + body: JSON.stringify(codeRequest), + signal: AbortSignal.any([controller.signal, ctx.abort]), + }) - if (!response.ok) { - const errorText = await response.text() - throw new Error(`Code search error (${response.status}): ${errorText}`) - } + if (!response.ok) { + const errorText = await response.text() + throw new Error(`Code search error (${response.status}): ${errorText}`) + } - const responseText = await response.text() - - // Parse SSE response - const lines = responseText.split("\n") - for (const line of lines) { - if (line.startsWith("data: ")) { - const data: McpCodeResponse = JSON.parse(line.substring(6)) - if (data.result && data.result.content && data.result.content.length > 0) { - return { - output: data.result.content[0].text, - title: `Code search: ${params.query}`, - metadata: {}, + const responseText = await response.text() + + // Parse SSE response + const lines = responseText.split("\n") + for (const line of lines) { + if (line.startsWith("data: ")) { + const data: McpCodeResponse = JSON.parse(line.substring(6)) + if (data.result && data.result.content && data.result.content.length > 0) { + return data.result.content[0].text + } + } } + + return null + } finally { + clearTimeout(timeoutId) } + }, + { + attempts: 3, + delay: 1000, + factor: 2, + maxDelay: 10000, + retryIf: isRetryableError, + }, + ) + + if (result) { + // Cache successful results + await codeCache.set(cacheKey, result) + + return { + output: result, + title: `Code search: ${params.query}`, + metadata: { cached: false }, } } @@ -117,11 +175,9 @@ export const CodeSearchTool = Tool.define("codesearch", { output: "No code snippets or documentation found. Please try a different query, be more specific about the library or programming concept, or check the spelling of framework names.", title: `Code search: ${params.query}`, - metadata: {}, + metadata: { cached: false }, } } catch (error) { - clearTimeout(timeoutId) - if (error instanceof Error && error.name === "AbortError") { throw new Error("Code search request timed out") } diff --git a/packages/opencode/src/tool/webfetch.ts b/packages/opencode/src/tool/webfetch.ts index 634c68f4eea..b4a4db153ee 100644 --- a/packages/opencode/src/tool/webfetch.ts +++ b/packages/opencode/src/tool/webfetch.ts @@ -2,11 +2,46 @@ import z from "zod" import { Tool } from "./tool" import TurndownService from "turndown" import DESCRIPTION from "./webfetch.txt" +import { LRUCache } from "../util/cache" +import { retry, isRetryableError } from "@opencode-ai/util/retry" const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds const MAX_TIMEOUT = 120 * 1000 // 2 minutes +// Cache configuration: 30 min TTL (web content can change), 200 entries max +// We cache less aggressively than search since web content is more dynamic +const fetchCache = new LRUCache({ + namespace: "webfetch", + maxSize: 200, + ttl: 30 * 60 * 1000, // 30 minutes + persist: true, +}) + +/** + * Generate a cache key from fetch parameters + */ +function getCacheKey(url: string, format: string): string { + return JSON.stringify({ u: url, f: format }) +} + +/** + * Check if URL should be cached (skip dynamic content indicators) + */ +function shouldCache(url: string): boolean { + const skipPatterns = [ + /\bapi\b/i, + /\bgraphql\b/i, + /\bauth\b/i, + /\blogin\b/i, + /\bsession\b/i, + /\btoken\b/i, + /\brandom\b/i, + /\btimestamp\b/i, + ] + return !skipPatterns.some((pattern) => pattern.test(url)) +} + export const WebFetchTool = Tool.define("webfetch", { description: DESCRIPTION, parameters: z.object({ @@ -34,10 +69,21 @@ export const WebFetchTool = Tool.define("webfetch", { }, }) - const timeout = Math.min((params.timeout ?? DEFAULT_TIMEOUT / 1000) * 1000, MAX_TIMEOUT) + // Check cache first for cacheable URLs + const canCache = shouldCache(params.url) + if (canCache) { + const cacheKey = getCacheKey(params.url, params.format) + const cached = await fetchCache.get(cacheKey) + if (cached) { + return { + output: cached, + title: `${params.url} (cached)`, + metadata: { cached: true }, + } + } + } - const controller = new AbortController() - const timeoutId = setTimeout(() => controller.abort(), timeout) + const timeout = Math.min((params.timeout ?? DEFAULT_TIMEOUT / 1000) * 1000, MAX_TIMEOUT) // Build Accept header based on requested format with q parameters for fallbacks let acceptHeader = "*/*" @@ -56,83 +102,102 @@ export const WebFetchTool = Tool.define("webfetch", { "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8" } - const response = await fetch(params.url, { - signal: AbortSignal.any([controller.signal, ctx.abort]), - headers: { - "User-Agent": - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", - Accept: acceptHeader, - "Accept-Language": "en-US,en;q=0.9", - }, - }) - - clearTimeout(timeoutId) - - if (!response.ok) { - throw new Error(`Request failed with status code: ${response.status}`) - } - - // Check content length - const contentLength = response.headers.get("content-length") - if (contentLength && parseInt(contentLength) > MAX_RESPONSE_SIZE) { - throw new Error("Response too large (exceeds 5MB limit)") - } - - const arrayBuffer = await response.arrayBuffer() - if (arrayBuffer.byteLength > MAX_RESPONSE_SIZE) { - throw new Error("Response too large (exceeds 5MB limit)") - } - - const content = new TextDecoder().decode(arrayBuffer) - const contentType = response.headers.get("content-type") || "" - - const title = `${params.url} (${contentType})` - - // Handle content based on requested format and actual content type - switch (params.format) { - case "markdown": - if (contentType.includes("text/html")) { - const markdown = convertHTMLToMarkdown(content) - return { - output: markdown, - title, - metadata: {}, + try { + const result = await retry( + async () => { + // Create fresh timeout for each retry attempt + const controller = new AbortController() + const timeoutId = setTimeout(() => controller.abort(), timeout) + + try { + const response = await fetch(params.url, { + signal: AbortSignal.any([controller.signal, ctx.abort]), + headers: { + "User-Agent": + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + Accept: acceptHeader, + "Accept-Language": "en-US,en;q=0.9", + }, + }) + + if (!response.ok) { + throw new Error(`Request failed with status code: ${response.status}`) + } + + // Check content length + const contentLength = response.headers.get("content-length") + if (contentLength && parseInt(contentLength) > MAX_RESPONSE_SIZE) { + throw new Error("Response too large (exceeds 5MB limit)") + } + + const arrayBuffer = await response.arrayBuffer() + if (arrayBuffer.byteLength > MAX_RESPONSE_SIZE) { + throw new Error("Response too large (exceeds 5MB limit)") + } + + const content = new TextDecoder().decode(arrayBuffer) + const contentType = response.headers.get("content-type") || "" + + return { content, contentType } + } finally { + clearTimeout(timeoutId) } - } - return { - output: content, - title, - metadata: {}, - } - - case "text": - if (contentType.includes("text/html")) { - const text = await extractTextFromHTML(content) - return { - output: text, - title, - metadata: {}, + }, + { + attempts: 3, + delay: 500, + factor: 2, + maxDelay: 5000, + retryIf: isRetryableError, + }, + ) + + const { content, contentType } = result + const title = `${params.url} (${contentType})` + + let output: string + + // Handle content based on requested format and actual content type + switch (params.format) { + case "markdown": + if (contentType.includes("text/html")) { + output = convertHTMLToMarkdown(content) + } else { + output = content } - } - return { - output: content, - title, - metadata: {}, - } - - case "html": - return { - output: content, - title, - metadata: {}, - } + break - default: - return { - output: content, - title, - metadata: {}, - } + case "text": + if (contentType.includes("text/html")) { + output = await extractTextFromHTML(content) + } else { + output = content + } + break + + case "html": + default: + output = content + break + } + + // Cache successful results for cacheable URLs + if (canCache) { + const cacheKey = getCacheKey(params.url, params.format) + await fetchCache.set(cacheKey, output) + } + + return { + output, + title, + metadata: { cached: false }, + } + } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + throw new Error("Request timed out") + } + + throw error } }, }) diff --git a/packages/opencode/src/tool/websearch.ts b/packages/opencode/src/tool/websearch.ts index f6df36f10f9..71f1a8ee75e 100644 --- a/packages/opencode/src/tool/websearch.ts +++ b/packages/opencode/src/tool/websearch.ts @@ -1,6 +1,8 @@ import z from "zod" import { Tool } from "./tool" import DESCRIPTION from "./websearch.txt" +import { LRUCache } from "../util/cache" +import { retry, isRetryableError } from "@opencode-ai/util/retry" const API_CONFIG = { BASE_URL: "https://mcp.exa.ai", @@ -8,8 +10,17 @@ const API_CONFIG = { SEARCH: "/mcp", }, DEFAULT_NUM_RESULTS: 8, + TIMEOUT_MS: 25000, } as const +// Cache configuration: 1 hour TTL, 500 entries max +const searchCache = new LRUCache({ + namespace: "websearch", + maxSize: 500, + ttl: 60 * 60 * 1000, // 1 hour + persist: true, +}) + interface McpSearchRequest { jsonrpc: string id: number @@ -36,6 +47,25 @@ interface McpSearchResponse { } } +/** + * Generate a cache key from search parameters + */ +function getCacheKey(params: { + query: string + numResults?: number + livecrawl?: string + type?: string + contextMaxCharacters?: number +}): string { + return JSON.stringify({ + q: params.query.toLowerCase().trim(), + n: params.numResults ?? API_CONFIG.DEFAULT_NUM_RESULTS, + l: params.livecrawl ?? "fallback", + t: params.type ?? "auto", + c: params.contextMaxCharacters, + }) +} + export const WebSearchTool = Tool.define("websearch", { description: DESCRIPTION, parameters: z.object({ @@ -70,6 +100,19 @@ export const WebSearchTool = Tool.define("websearch", { }, }) + // Check cache first (only for non-preferred livecrawl) + if (params.livecrawl !== "preferred") { + const cacheKey = getCacheKey(params) + const cached = await searchCache.get(cacheKey) + if (cached) { + return { + output: cached, + title: `Web search: ${params.query} (cached)`, + metadata: { cached: true }, + } + } + } + const searchRequest: McpSearchRequest = { jsonrpc: "2.0", id: 1, @@ -86,54 +129,76 @@ export const WebSearchTool = Tool.define("websearch", { }, } - const controller = new AbortController() - const timeoutId = setTimeout(() => controller.abort(), 25000) - try { - const headers: Record = { - accept: "application/json, text/event-stream", - "content-type": "application/json", - } + const result = await retry( + async () => { + // Create fresh timeout for each retry attempt + const controller = new AbortController() + const timeoutId = setTimeout(() => controller.abort(), API_CONFIG.TIMEOUT_MS) - const response = await fetch(`${API_CONFIG.BASE_URL}${API_CONFIG.ENDPOINTS.SEARCH}`, { - method: "POST", - headers, - body: JSON.stringify(searchRequest), - signal: AbortSignal.any([controller.signal, ctx.abort]), - }) + try { + const headers: Record = { + accept: "application/json, text/event-stream", + "content-type": "application/json", + } - clearTimeout(timeoutId) + const response = await fetch(`${API_CONFIG.BASE_URL}${API_CONFIG.ENDPOINTS.SEARCH}`, { + method: "POST", + headers, + body: JSON.stringify(searchRequest), + signal: AbortSignal.any([controller.signal, ctx.abort]), + }) - if (!response.ok) { - const errorText = await response.text() - throw new Error(`Search error (${response.status}): ${errorText}`) - } + if (!response.ok) { + const errorText = await response.text() + throw new Error(`Search error (${response.status}): ${errorText}`) + } - const responseText = await response.text() - - // Parse SSE response - const lines = responseText.split("\n") - for (const line of lines) { - if (line.startsWith("data: ")) { - const data: McpSearchResponse = JSON.parse(line.substring(6)) - if (data.result && data.result.content && data.result.content.length > 0) { - return { - output: data.result.content[0].text, - title: `Web search: ${params.query}`, - metadata: {}, + const responseText = await response.text() + + // Parse SSE response + const lines = responseText.split("\n") + for (const line of lines) { + if (line.startsWith("data: ")) { + const data: McpSearchResponse = JSON.parse(line.substring(6)) + if (data.result && data.result.content && data.result.content.length > 0) { + return data.result.content[0].text + } + } } + + return null + } finally { + clearTimeout(timeoutId) } + }, + { + attempts: 3, + delay: 1000, + factor: 2, + maxDelay: 10000, + retryIf: isRetryableError, + }, + ) + + if (result) { + // Cache successful results + const cacheKey = getCacheKey(params) + await searchCache.set(cacheKey, result) + + return { + output: result, + title: `Web search: ${params.query}`, + metadata: { cached: false }, } } return { output: "No search results found. Please try a different query.", title: `Web search: ${params.query}`, - metadata: {}, + metadata: { cached: false }, } } catch (error) { - clearTimeout(timeoutId) - if (error instanceof Error && error.name === "AbortError") { throw new Error("Search request timed out") } diff --git a/packages/opencode/src/util/cache.ts b/packages/opencode/src/util/cache.ts new file mode 100644 index 00000000000..9f6831434fd --- /dev/null +++ b/packages/opencode/src/util/cache.ts @@ -0,0 +1,237 @@ +import fs from "fs/promises" +import path from "path" +import { Global } from "../global" +import { createHash } from "crypto" +import { Log } from "./log" + +export interface CacheOptions { + /** Maximum number of items in memory cache */ + maxSize?: number + /** Time-to-live in milliseconds (default: 1 hour) */ + ttl?: number + /** Whether to persist to disk (default: true) */ + persist?: boolean + /** Cache namespace/subdirectory */ + namespace: string +} + +interface CacheEntry { + value: T + timestamp: number + ttl: number +} + +export interface CacheStats { + namespace: string + memorySize: number + maxSize: number + hits: number + misses: number + hitRate: number +} + +const log = Log.create({ service: "cache" }) + +/** + * LRU Cache with optional disk persistence + * - In-memory LRU cache for fast access + * - Disk persistence for cache survival across restarts + * - TTL support for automatic expiration + */ +export class LRUCache { + private cache: Map> = new Map() + private readonly maxSize: number + private readonly ttl: number + private readonly persist: boolean + private readonly cacheDir: string + private readonly namespace: string + private initialized = false + private hits = 0 + private misses = 0 + + constructor(options: CacheOptions) { + this.namespace = options.namespace + this.maxSize = options.maxSize ?? 1000 + this.ttl = options.ttl ?? 60 * 60 * 1000 // 1 hour default + this.persist = options.persist ?? true + this.cacheDir = path.join(Global.Path.cache, options.namespace) + } + + private async ensureInit(): Promise { + if (this.initialized) return + if (this.persist) { + await fs.mkdir(this.cacheDir, { recursive: true }) + } + this.initialized = true + } + + private hashKey(key: string): string { + return createHash("sha256").update(key).digest("hex").slice(0, 16) + } + + private isExpired(entry: CacheEntry): boolean { + return Date.now() - entry.timestamp > entry.ttl + } + + private evictOldest(): void { + if (this.cache.size >= this.maxSize) { + // Map maintains insertion order, first key is oldest + const firstKey = this.cache.keys().next().value + if (firstKey) { + this.cache.delete(firstKey) + } + } + } + + /** + * Get a value from cache + * First checks memory, then disk if persistence is enabled + */ + async get(key: string): Promise { + await this.ensureInit() + + const hashedKey = this.hashKey(key) + + // Check memory cache first + const memEntry = this.cache.get(hashedKey) + if (memEntry) { + if (this.isExpired(memEntry)) { + this.cache.delete(hashedKey) + if (this.persist) { + await this.deleteFromDisk(hashedKey).catch(() => {}) + } + this.misses++ + log.info("cache miss (expired)", { namespace: this.namespace }) + return undefined + } + // Move to end (most recently used) + this.cache.delete(hashedKey) + this.cache.set(hashedKey, memEntry) + this.hits++ + log.info("cache hit", { namespace: this.namespace, source: "memory" }) + return memEntry.value + } + + // Check disk cache if persistence is enabled + if (this.persist) { + const diskEntry = await this.readFromDisk(hashedKey) + if (diskEntry) { + if (this.isExpired(diskEntry)) { + await this.deleteFromDisk(hashedKey).catch(() => {}) + this.misses++ + log.info("cache miss (expired)", { namespace: this.namespace }) + return undefined + } + // Add back to memory cache + this.evictOldest() + this.cache.set(hashedKey, diskEntry) + this.hits++ + log.info("cache hit", { namespace: this.namespace, source: "disk" }) + return diskEntry.value + } + } + + this.misses++ + log.info("cache miss", { namespace: this.namespace }) + return undefined + } + + /** + * Set a value in cache + * Stores in memory and optionally persists to disk + */ + async set(key: string, value: T, ttl?: number): Promise { + await this.ensureInit() + + const hashedKey = this.hashKey(key) + const entry: CacheEntry = { + value, + timestamp: Date.now(), + ttl: ttl ?? this.ttl, + } + + // Evict oldest if at capacity + this.evictOldest() + + // Set in memory + this.cache.set(hashedKey, entry) + + // Persist to disk if enabled + if (this.persist) { + await this.writeToDisk(hashedKey, entry).catch(() => { + // Silently fail disk writes - memory cache still works + }) + } + } + + /** + * Check if a key exists and is not expired + */ + async has(key: string): Promise { + const value = await this.get(key) + return value !== undefined + } + + /** + * Delete a key from cache + */ + async delete(key: string): Promise { + const hashedKey = this.hashKey(key) + this.cache.delete(hashedKey) + if (this.persist) { + await this.deleteFromDisk(hashedKey).catch(() => {}) + } + } + + /** + * Clear all cache entries + */ + async clear(): Promise { + this.cache.clear() + this.hits = 0 + this.misses = 0 + if (this.persist) { + try { + const files = await fs.readdir(this.cacheDir) + await Promise.all(files.map((file) => fs.unlink(path.join(this.cacheDir, file)).catch(() => {}))) + } catch { + // Directory might not exist + } + } + } + + /** + * Get cache statistics + */ + stats(): CacheStats { + const total = this.hits + this.misses + return { + namespace: this.namespace, + memorySize: this.cache.size, + maxSize: this.maxSize, + hits: this.hits, + misses: this.misses, + hitRate: total > 0 ? this.hits / total : 0, + } + } + + private async readFromDisk(hashedKey: string): Promise | undefined> { + try { + const filePath = path.join(this.cacheDir, `${hashedKey}.json`) + const content = await fs.readFile(filePath, "utf-8") + return JSON.parse(content) as CacheEntry + } catch { + return undefined + } + } + + private async writeToDisk(hashedKey: string, entry: CacheEntry): Promise { + const filePath = path.join(this.cacheDir, `${hashedKey}.json`) + await fs.writeFile(filePath, JSON.stringify(entry), "utf-8") + } + + private async deleteFromDisk(hashedKey: string): Promise { + const filePath = path.join(this.cacheDir, `${hashedKey}.json`) + await fs.unlink(filePath) + } +} diff --git a/packages/opencode/test/tool/cache.test.ts b/packages/opencode/test/tool/cache.test.ts new file mode 100644 index 00000000000..d8b199c97db --- /dev/null +++ b/packages/opencode/test/tool/cache.test.ts @@ -0,0 +1,159 @@ +import { test, expect } from "bun:test" +import { LRUCache } from "../../src/util/cache" + +test("LRUCache: basic get/set operations", async () => { + const cache = new LRUCache({ + namespace: "test-basic", + maxSize: 10, + ttl: 60000, + persist: false, + }) + + await cache.set("key1", "value1") + const result = await cache.get("key1") + expect(result).toBe("value1") + + const missing = await cache.get("nonexistent") + expect(missing).toBeUndefined() +}) + +test("LRUCache: respects maxSize and evicts oldest", async () => { + const cache = new LRUCache({ + namespace: "test-eviction", + maxSize: 3, + ttl: 60000, + persist: false, + }) + + await cache.set("key1", "value1") + await cache.set("key2", "value2") + await cache.set("key3", "value3") + await cache.set("key4", "value4") + + expect(await cache.get("key1")).toBeUndefined() + expect(await cache.get("key2")).toBe("value2") + expect(await cache.get("key3")).toBe("value3") + expect(await cache.get("key4")).toBe("value4") +}) + +test("LRUCache: TTL expiration", async () => { + const cache = new LRUCache({ + namespace: "test-ttl", + maxSize: 10, + ttl: 50, + persist: false, + }) + + await cache.set("key1", "value1") + expect(await cache.get("key1")).toBe("value1") + + await new Promise((resolve) => setTimeout(resolve, 100)) + + expect(await cache.get("key1")).toBeUndefined() +}) + +test("LRUCache: stats reporting with namespace", async () => { + const cache = new LRUCache({ + namespace: "test-stats-ns", + maxSize: 100, + ttl: 60000, + persist: false, + }) + + await cache.set("key1", "value1") + await cache.set("key2", "value2") + + const stats = cache.stats() + expect(stats.namespace).toBe("test-stats-ns") + expect(stats.memorySize).toBe(2) + expect(stats.maxSize).toBe(100) +}) + +test("LRUCache: hit/miss counters", async () => { + const cache = new LRUCache({ + namespace: "test-counters", + maxSize: 10, + ttl: 60000, + persist: false, + }) + + // Initial state - no hits or misses + let stats = cache.stats() + expect(stats.hits).toBe(0) + expect(stats.misses).toBe(0) + expect(stats.hitRate).toBe(0) + + // Set a value + await cache.set("key1", "value1") + + // Get existing key - should be a hit + await cache.get("key1") + stats = cache.stats() + expect(stats.hits).toBe(1) + expect(stats.misses).toBe(0) + expect(stats.hitRate).toBe(1) + + // Get non-existent key - should be a miss + await cache.get("nonexistent") + stats = cache.stats() + expect(stats.hits).toBe(1) + expect(stats.misses).toBe(1) + expect(stats.hitRate).toBe(0.5) + + // Another hit + await cache.get("key1") + stats = cache.stats() + expect(stats.hits).toBe(2) + expect(stats.misses).toBe(1) + expect(stats.hitRate).toBeCloseTo(0.666, 2) +}) + +test("LRUCache: clear resets counters", async () => { + const cache = new LRUCache({ + namespace: "test-clear-counters", + maxSize: 10, + ttl: 60000, + persist: false, + }) + + await cache.set("key1", "value1") + await cache.get("key1") // hit + await cache.get("missing") // miss + + let stats = cache.stats() + expect(stats.hits).toBe(1) + expect(stats.misses).toBe(1) + + await cache.clear() + + stats = cache.stats() + expect(stats.hits).toBe(0) + expect(stats.misses).toBe(0) + expect(stats.memorySize).toBe(0) +}) + +test("LRUCache: expired entries count as misses", async () => { + const cache = new LRUCache({ + namespace: "test-expired-miss", + maxSize: 10, + ttl: 50, // 50ms TTL + persist: false, + }) + + await cache.set("key1", "value1") + + // Hit while fresh + await cache.get("key1") + let stats = cache.stats() + expect(stats.hits).toBe(1) + expect(stats.misses).toBe(0) + + // Wait for expiration + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Miss due to expiration + await cache.get("key1") + stats = cache.stats() + expect(stats.hits).toBe(1) + expect(stats.misses).toBe(1) +}) diff --git a/packages/opencode/test/util/retry.test.ts b/packages/opencode/test/util/retry.test.ts new file mode 100644 index 00000000000..e86c31ecbb7 --- /dev/null +++ b/packages/opencode/test/util/retry.test.ts @@ -0,0 +1,166 @@ +import { test, expect } from "bun:test" +import { retry, isTransientError, isRateLimitError, isServerError, isRetryableError } from "@opencode-ai/util/retry" + +// Test error detection functions +test("isTransientError: detects network errors", () => { + expect(isTransientError(new Error("Failed to fetch"))).toBe(true) + expect(isTransientError(new Error("Network request failed"))).toBe(true) + expect(isTransientError(new Error("ECONNRESET"))).toBe(true) + expect(isTransientError(new Error("ECONNREFUSED"))).toBe(true) + expect(isTransientError(new Error("ETIMEDOUT"))).toBe(true) + expect(isTransientError(new Error("Socket hang up"))).toBe(true) + expect(isTransientError(new Error("load failed"))).toBe(true) + + // Non-transient errors + expect(isTransientError(new Error("Not found"))).toBe(false) + expect(isTransientError(new Error("Invalid input"))).toBe(false) + expect(isTransientError(null)).toBe(false) + expect(isTransientError(undefined)).toBe(false) +}) + +test("isRateLimitError: detects rate limit errors", () => { + expect(isRateLimitError(new Error("Error 429: Too many requests"))).toBe(true) + expect(isRateLimitError(new Error("Rate limit exceeded"))).toBe(true) + expect(isRateLimitError(new Error("Too many requests"))).toBe(true) + + // Not rate limit errors + expect(isRateLimitError(new Error("Not found"))).toBe(false) + expect(isRateLimitError(new Error("500 Internal Server Error"))).toBe(false) +}) + +test("isServerError: detects server errors", () => { + expect(isServerError(new Error("500 Internal Server Error"))).toBe(true) + expect(isServerError(new Error("Error 502: Bad Gateway"))).toBe(true) + expect(isServerError(new Error("503 Service Unavailable"))).toBe(true) + expect(isServerError(new Error("504 Gateway Timeout"))).toBe(true) + + // Not server errors + expect(isServerError(new Error("404 Not Found"))).toBe(false) + expect(isServerError(new Error("400 Bad Request"))).toBe(false) + expect(isServerError(new Error("Network error"))).toBe(false) +}) + +test("isRetryableError: combines all retryable error types", () => { + // Transient errors + expect(isRetryableError(new Error("Failed to fetch"))).toBe(true) + expect(isRetryableError(new Error("ECONNRESET"))).toBe(true) + + // Rate limit errors + expect(isRetryableError(new Error("429 Too many requests"))).toBe(true) + + // Server errors + expect(isRetryableError(new Error("500 Internal Server Error"))).toBe(true) + expect(isRetryableError(new Error("503 Service Unavailable"))).toBe(true) + + // Not retryable + expect(isRetryableError(new Error("404 Not Found"))).toBe(false) + expect(isRetryableError(new Error("Invalid input"))).toBe(false) +}) + +// Test retry function +test("retry: succeeds on first attempt", async () => { + let attempts = 0 + const result = await retry(async () => { + attempts++ + return "success" + }) + + expect(result).toBe("success") + expect(attempts).toBe(1) +}) + +test("retry: retries on transient error and succeeds", async () => { + let attempts = 0 + const result = await retry( + async () => { + attempts++ + if (attempts < 3) { + throw new Error("Failed to fetch") + } + return "success" + }, + { attempts: 5, delay: 10 }, + ) + + expect(result).toBe("success") + expect(attempts).toBe(3) +}) + +test("retry: throws after max attempts", async () => { + let attempts = 0 + + await expect( + retry( + async () => { + attempts++ + throw new Error("Failed to fetch") + }, + { attempts: 3, delay: 10 }, + ), + ).rejects.toThrow("Failed to fetch") + + expect(attempts).toBe(3) +}) + +test("retry: does not retry non-transient errors by default", async () => { + let attempts = 0 + + await expect( + retry( + async () => { + attempts++ + throw new Error("Invalid input") + }, + { attempts: 3, delay: 10 }, + ), + ).rejects.toThrow("Invalid input") + + expect(attempts).toBe(1) // No retries for non-transient error +}) + +test("retry: uses custom retryIf function", async () => { + let attempts = 0 + + const result = await retry( + async () => { + attempts++ + if (attempts < 2) { + throw new Error("Custom retryable error") + } + return "success" + }, + { + attempts: 3, + delay: 10, + retryIf: (error) => error instanceof Error && error.message.includes("Custom"), + }, + ) + + expect(result).toBe("success") + expect(attempts).toBe(2) +}) + +test("retry: uses isRetryableError for HTTP errors", async () => { + let attempts = 0 + + const result = await retry( + async () => { + attempts++ + if (attempts === 1) { + throw new Error("503 Service Unavailable") + } + if (attempts === 2) { + throw new Error("429 Too many requests") + } + return "success" + }, + { + attempts: 5, + delay: 10, + retryIf: isRetryableError, + }, + ) + + expect(result).toBe("success") + expect(attempts).toBe(3) +}) diff --git a/packages/util/src/retry.ts b/packages/util/src/retry.ts index 0014a604c93..268ed3981f1 100644 --- a/packages/util/src/retry.ts +++ b/packages/util/src/retry.ts @@ -17,12 +17,53 @@ const TRANSIENT_MESSAGES = [ "socket hang up", ] -function isTransientError(error: unknown): boolean { +const RATE_LIMIT_MESSAGES = ["429", "rate limit", "too many requests"] + +const SERVER_ERROR_MESSAGES = [ + "500", + "502", + "503", + "504", + "internal server error", + "bad gateway", + "service unavailable", +] + +/** + * Check if an error is a transient network error + */ +export function isTransientError(error: unknown): boolean { if (!error) return false const message = String(error instanceof Error ? error.message : error).toLowerCase() return TRANSIENT_MESSAGES.some((m) => message.includes(m)) } +/** + * Check if an error is a rate limit error (HTTP 429) + */ +export function isRateLimitError(error: unknown): boolean { + if (!error) return false + const message = String(error instanceof Error ? error.message : error).toLowerCase() + return RATE_LIMIT_MESSAGES.some((m) => message.includes(m)) +} + +/** + * Check if an error is a server error (HTTP 5xx) + */ +export function isServerError(error: unknown): boolean { + if (!error) return false + const message = String(error instanceof Error ? error.message : error).toLowerCase() + return SERVER_ERROR_MESSAGES.some((m) => message.includes(m)) +} + +/** + * Check if an error is retryable (network issues, rate limits, or server errors) + * Use this for HTTP/fetch operations where you want to retry on transient failures + */ +export function isRetryableError(error: unknown): boolean { + return isTransientError(error) || isRateLimitError(error) || isServerError(error) +} + export async function retry(fn: () => Promise, options: RetryOptions = {}): Promise { const { attempts = 3, delay = 500, factor = 2, maxDelay = 10000, retryIf = isTransientError } = options