diff --git a/plugins/redis/README.md b/plugins/redis/README.md new file mode 100644 index 0000000000..066e17f772 --- /dev/null +++ b/plugins/redis/README.md @@ -0,0 +1,364 @@ +# Redis Cache Plugin for Bifrost + +This plugin provides Redis-based caching functionality for Bifrost requests. It caches responses based on request body hashes and returns cached responses for identical requests, significantly improving performance and reducing API costs. + +## Features + +- **High-Performance Hashing**: Uses xxhash for ultra-fast request body hashing +- **Asynchronous Caching**: Non-blocking cache writes for optimal response times +- **Response Caching**: Stores complete responses in Redis with configurable TTL +- **Streaming Cache Support**: Caches and retrieves streaming responses chunk by chunk +- **Cache Hit Detection**: Returns cached responses for identical requests +- **Intelligent Cache Recovery**: Automatically reconstructs streaming responses from cached chunks +- **Simple Setup**: Only requires Redis address and cache key - sensible defaults for everything else +- **Self-Contained**: Creates and manages its own Redis client + +## Installation + +```bash +go get github.com/maximhq/bifrost/core +go get github.com/maximhq/bifrost/plugins/redis +``` + +## Quick Start + +### Basic Setup + +```go +import ( + "github.com/maximhq/bifrost/plugins/redis" + bifrost "github.com/maximhq/bifrost/core" +) + +// Simple configuration - only Redis address and cache key are required! +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", // Your Redis server address + CacheKey: "x-my-cache-key", // Context key for cache identification +} + +// Create the plugin +plugin, err := redis.NewRedisPlugin(config, logger) +if err != nil { + log.Fatal("Failed to create Redis plugin:", err) +} + +// Use with Bifrost +bifrostConfig := schemas.BifrostConfig{ + Account: yourAccount, + Plugins: []schemas.Plugin{plugin}, + // ... other config +} +``` + +That's it! 
The plugin uses Redis client defaults for connection handling and these defaults for caching: + +- **TTL**: 5 minutes +- **CacheByModel**: true (include model in cache key) +- **CacheByProvider**: true (include provider in cache key) + +**Important**: You must provide the cache key in your request context for caching to work: + +```go +ctx := context.WithValue(ctx, redis.ContextKey("x-my-cache-key"), "cache-value") +response, err := client.ChatCompletionRequest(ctx, request) +``` + +### With Password Authentication + +```go +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: "x-my-cache-key", + Password: "your-redis-password", +} +``` + +### With Custom TTL and Prefix + +```go +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: "x-my-cache-key", + TTL: time.Hour, // Cache for 1 hour + Prefix: "myapp:cache:", // Custom prefix +} +``` + +### Per-Request TTL Override (via Context) + +You can override the cache TTL for individual requests by providing a TTL in the request context. Configure a `CacheTTLKey` on the plugin, then set a `time.Duration` value at that context key before making the request. + +```go +// Configure plugin with a context key used to read per-request TTLs +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: "x-my-cache-key", + CacheTTLKey: "x-my-cache-ttl", // The context key for reading TTL + TTL: 5 * time.Minute, // Fallback/default TTL +} + +plugin, err := redis.NewRedisPlugin(config, logger) +// ... init Bifrost client with plugin + +// Before making a request, set a per-request TTL +ctx := context.WithValue(ctx, redis.ContextKey("x-my-cache-ttl"), 30*time.Second) +resp, err := client.ChatCompletionRequest(ctx, request) +``` + +Notes: + +- The context value must be of type `time.Duration`. If it is missing or of the wrong type, the plugin falls back to `config.TTL`. +- This applies to both regular and streaming requests. 
For streaming, the same per-request TTL applies to all chunks. + +### With Different Database + +```go +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: "x-my-cache-key", + DB: 1, // Use Redis database 1 +} +``` + +### Streaming Cache Example + +```go +// Configure plugin for streaming cache +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: "x-stream-cache-key", + TTL: 30 * time.Minute, // Cache streaming responses for 30 minutes +} + +// Use with streaming requests +ctx := context.WithValue(ctx, redis.ContextKey("x-stream-cache-key"), "stream-session-1") +stream, err := client.ChatCompletionStreamRequest(ctx, request) +// Subsequent identical requests will be served from cache as a reconstructed stream +``` + +### Custom Cache Key Configuration + +```go +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: "x-my-cache-key", + CacheByModel: bifrost.Ptr(false), // Don't include model in cache key + CacheByProvider: bifrost.Ptr(true), // Include provider in cache key +} +``` + +### Custom Redis Client Configuration + +```go +config := redis.RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: "x-my-cache-key", + PoolSize: 20, // Custom connection pool size + DialTimeout: 5 * time.Second, // Custom connection timeout + ReadTimeout: 3 * time.Second, // Custom read timeout + ConnMaxLifetime: time.Hour, // Custom connection lifetime +} +``` + +## Configuration Options + +| Option | Type | Required | Default | Description | +| ----------------- | --------------- | -------- | ----------------- | ----------------------------------- | +| `Addr` | `string` | ✅ | - | Redis server address (host:port) | +| `CacheKey` | `string` | ✅ | - | Context key for cache identification| +| `Username` | `string` | ❌ | `""` | Username for Redis AUTH (Redis 6+) | +| `Password` | `string` | ❌ | `""` | Password for Redis AUTH | +| `DB` | `int` | ❌ | `0` | Redis database number | +| `TTL` | `time.Duration` | ❌ | `5 * 
time.Minute` | Time-to-live for cached responses | +| `Prefix` | `string` | ❌ | `""` | Prefix for cache keys | +| `CacheByModel` | `*bool` | ❌ | `true` | Include model in cache key | +| `CacheByProvider` | `*bool` | ❌ | `true` | Include provider in cache key | + +**Redis Connection Options** (all optional, Redis client uses its own defaults for zero values): + +- `PoolSize`, `MinIdleConns`, `MaxIdleConns` - Connection pool settings +- `ConnMaxLifetime`, `ConnMaxIdleTime` - Connection lifetime settings +- `DialTimeout`, `ReadTimeout`, `WriteTimeout` - Timeout settings + +All Redis configuration values are passed directly to the Redis client, which handles its own zero-value defaults. You only need to specify values you want to override from Redis client defaults. + +## How It Works + +The plugin generates an xxhash of the normalized request including: + +- Provider (if CacheByProvider is true) +- Model (if CacheByModel is true) +- Input (chat completion or text completion) +- Parameters (includes tool calls) + +Identical requests will always produce the same hash, enabling effective caching. + +### Caching Flow + +#### Regular Requests + +1. **PreHook**: Checks Redis for cached response, returns immediately if found +2. **PostHook**: Stores the response in Redis asynchronously (non-blocking) +3. **Cleanup**: Clears all cached entries and closes connection on shutdown + +#### Streaming Requests + +1. **PreHook**: Checks Redis for cached chunks using pattern `{cache_key}_chunk_*` +2. **Cache Hit**: Reconstructs stream from cached chunks in correct order +3. **PostHook**: Stores each stream chunk with index: `{cache_key}_chunk_{index}` +4. **Stream Reconstruction**: Subsequent requests get sorted chunks as a new stream + +**Asynchronous Caching**: Cache writes happen in background goroutines with a 30-second timeout, ensuring responses are never delayed by Redis operations. This provides optimal performance while maintaining cache functionality. 
+ +**Streaming Intelligence**: The plugin automatically detects streaming requests and handles chunk-based caching. Each chunk is stored with its index, allowing perfect reconstruction of the original stream order. + +### Cache Keys + +#### Regular Responses + +Cache keys follow the pattern: `{prefix}{cache_value}_{xxhash}` + +Example: `bifrost:cache:my-session_a1b2c3d4e5f6...` + +#### Streaming Responses + +Chunk keys follow the pattern: `{prefix}{cache_value}_{xxhash}_chunk_{index}` + +Examples: + +- `bifrost:cache:my-session_a1b2c3d4e5f6..._chunk_0` +- `bifrost:cache:my-session_a1b2c3d4e5f6..._chunk_1` +- `bifrost:cache:my-session_a1b2c3d4e5f6..._chunk_2` + +## Manual Cache Invalidation + +You can invalidate specific cached entries at runtime using the method `ClearCacheForKey(key string)` on the concrete `redis.Plugin` type. This deletes the provided key and, if it corresponds to a streaming response, all of its chunk entries (`_chunk_*`). + +### Getting the cache key from responses + +- **Regular responses**: When a response is served from cache, the plugin adds metadata to `response.ExtraFields.RawResponse`: + - `bifrost_cached: true` + - `bifrost_cache_key: "{cache_value}_{hash}"` + Use this `bifrost_cache_key` as the argument to `ClearCacheForKey`. + +- **Streaming responses**: Cached stream chunks include `bifrost_cache_key` for the specific chunk, in the form `"{cache_value}_{hash}_chunk_{index}"`. To invalidate the entire stream cache, strip the `"_chunk_{index}"` suffix to obtain the base key and pass that base key to `ClearCacheForKey`. 
+ +### Examples + +```go +// Non-streaming: clear the cached response you just used +resp, err := client.ChatCompletionRequest(ctx, req) +if err != nil { + // handle error +} + +if resp != nil && resp.ExtraFields.RawResponse != nil { + if raw, ok := resp.ExtraFields.RawResponse.(map[string]interface{}); ok { + if keyAny, ok := raw["bifrost_cache_key"]; ok { + if pluginImpl, ok := plugin.(*redis.Plugin); ok { + _ = pluginImpl.ClearCacheForKey(keyAny.(string)) + } + } + } +} +``` + +```go +// Streaming: clear all chunks for a cached stream +for msg := range stream { + if msg.BifrostResponse == nil { + continue + } + raw := msg.BifrostResponse.ExtraFields.RawResponse + rawMap, ok := raw.(map[string]interface{}) + if !ok { + continue + } + keyAny, ok := rawMap["bifrost_cache_key"] + if !ok { + continue + } + chunkKey := keyAny.(string) // e.g., "_chunk_3" + + // Derive base key by removing the trailing "_chunk_{index}" part + baseKey := chunkKey + if idx := strings.LastIndex(chunkKey, "_chunk_"); idx != -1 { + baseKey = chunkKey[:idx] + } + + if pluginImpl, ok := plugin.(*redis.Plugin); ok { + _ = pluginImpl.ClearCacheForKey(baseKey) + } + break // we only need one chunk to compute the base key +} +``` + +To clear all entries managed by this plugin (by prefix), call `Cleanup()` during shutdown: + +```go +_ = plugin.(*redis.Plugin).Cleanup() +``` + +## Testing + +The plugin includes comprehensive tests for both regular and streaming cache functionality. + +Run the tests with a Redis instance running: + +```bash +# Start Redis (using Docker) +docker run -d -p 6379:6379 redis:latest + +# Run all tests +go test -v + +# Run specific tests +go test -run TestRedisPlugin -v # Test regular caching +go test -run TestRedisPluginStreaming -v # Test streaming cache +``` + +Tests will be skipped if Redis is not available. 
The tests validate: + +- Cache hit/miss behavior +- Performance improvements (cache should be significantly faster) +- Content integrity (cached responses match originals) +- Streaming chunk ordering and reconstruction +- Provider information preservation + +## Performance Benefits + +- **Reduced API Calls**: Identical requests are served from cache +- **Ultra-Low Latency**: Cache hits return immediately, cache writes are non-blocking +- **Streaming Efficiency**: Cached streams are reconstructed and delivered faster than original API calls +- **Cost Savings**: Fewer API calls to expensive LLM providers +- **Improved Reliability**: Cached responses available even if provider is down +- **High Throughput**: Asynchronous caching doesn't impact response times +- **Perfect Stream Fidelity**: Cached streams maintain exact chunk ordering and content + +## Error Handling + +The plugin is designed to fail gracefully: + +- If Redis is unavailable during startup, plugin creation fails with clear error +- If Redis becomes unavailable during operation, requests continue without caching +- If cache retrieval fails, requests proceed normally +- If cache storage fails asynchronously, responses are unaffected (already returned) +- Malformed cached data is ignored and requests proceed normally +- Cache operations have timeouts to prevent resource leaks + +## Best Practices + +1. **Start Simple**: Use only `Addr` and `CacheKey` - let defaults handle the rest +2. **Choose meaningful cache keys**: Use descriptive context keys that identify cache sessions +3. **Set appropriate TTL**: Balance between cache efficiency and data freshness +4. **Use meaningful prefixes**: Helps organize cache keys in shared Redis instances +5. **Monitor Redis memory**: Track cache usage, especially for streaming responses (more chunks = more storage) +6. **Context management**: Always provide cache key in request context for caching to work +7. 
**Use `bifrost.Ptr()`**: For boolean pointer configuration options +8. **Streaming considerations**: Longer streams create more cache entries, adjust TTL accordingly + +## Security Considerations + +- **Sensitive Data**: Be cautious about caching responses containing sensitive information +- **Redis Security**: Use authentication and network security for Redis +- **Data Isolation**: Use different Redis databases or prefixes for different environments diff --git a/plugins/redis/go.mod b/plugins/redis/go.mod new file mode 100644 index 0000000000..cab7d3e384 --- /dev/null +++ b/plugins/redis/go.mod @@ -0,0 +1,50 @@ +module github.com/maximhq/bifrost/plugins/redis + +go 1.24.1 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 + github.com/maximhq/bifrost/core v1.1.16 + github.com/redis/go-redis/v9 v9.10.0 +) + +replace github.com/maximhq/bifrost/core => ../../core + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.3 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.5 
// indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/mark3labs/mcp-go v0.32.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.60.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.32.0 // indirect + golang.org/x/text v0.24.0 // indirect +) diff --git a/plugins/redis/go.sum b/plugins/redis/go.sum new file mode 100644 index 0000000000..47805f9be8 --- /dev/null +++ b/plugins/redis/go.sum @@ -0,0 +1,122 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod 
h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod 
h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= +github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous 
v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= 
+github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs= +github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 
h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= +github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= 
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/plugins/redis/main.go b/plugins/redis/main.go new file mode 100644 index 0000000000..1a053b73ba --- /dev/null +++ b/plugins/redis/main.go @@ -0,0 +1,599 @@ +// Package redis provides Redis caching integration for Bifrost plugin. +// This plugin caches request body hashes using xxhash and returns cached responses for identical requests. +// It supports configurable caching behavior including success-only caching and custom cache key generation. +package redis + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "time" + + "github.com/cespare/xxhash/v2" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/redis/go-redis/v9" +) + +// RedisPluginConfig contains configuration for the Redis plugin. +// All Redis client options are passed directly to the Redis client, which handles its own defaults. +// Only specify values you want to override from Redis client defaults. 
+type RedisPluginConfig struct { + // Connection settings + Addr string `json:"addr"` // Redis server address (host:port) - REQUIRED + Username string `json:"username,omitempty"` // Username for Redis AUTH (optional) + Password string `json:"password,omitempty"` // Password for Redis AUTH (optional) + DB int `json:"db,omitempty"` // Redis database number (default: 0) + CacheKey string `json:"cache_key"` // Cache key for context lookup - REQUIRED + CacheTTLKey string `json:"cache_ttl_key"` // Cache TTL key for context lookup (optional) + + // Connection pool and timeout settings (passed directly to Redis client) + PoolSize int `json:"pool_size,omitempty"` // Maximum number of socket connections (optional) + MinIdleConns int `json:"min_idle_conns,omitempty"` // Minimum number of idle connections (optional) + MaxIdleConns int `json:"max_idle_conns,omitempty"` // Maximum number of idle connections (optional) + ConnMaxLifetime time.Duration `json:"conn_max_lifetime,omitempty"` // Connection maximum lifetime (optional) + ConnMaxIdleTime time.Duration `json:"conn_max_idle_time,omitempty"` // Connection maximum idle time (optional) + DialTimeout time.Duration `json:"dial_timeout,omitempty"` // Timeout for socket connection (optional) + ReadTimeout time.Duration `json:"read_timeout,omitempty"` // Timeout for socket reads (optional) + WriteTimeout time.Duration `json:"write_timeout,omitempty"` // Timeout for socket writes (optional) + ContextTimeout time.Duration `json:"context_timeout,omitempty"` // Timeout for Redis operations (optional) + + // Plugin behavior settings + TTL time.Duration `json:"ttl,omitempty"` // Time-to-live for cached responses (default: 5min) + Prefix string `json:"prefix,omitempty"` // Prefix for cache keys (optional) + + // Advanced caching behavior + CacheByModel *bool `json:"cache_by_model,omitempty"` // Include model in cache key (default: true) + CacheByProvider *bool `json:"cache_by_provider,omitempty"` // Include provider in cache key (default: 
true) +} + +// Plugin implements the schemas.Plugin interface for Redis caching. +// It caches responses based on xxhash of normalized requests and returns cached +// responses for identical requests. The plugin supports configurable caching behavior +// including success-only caching and custom cache key generation. +// +// Fields: +// - client: Redis client instance for cache operations +// - config: Plugin configuration including Redis and caching settings +// - logger: Logger instance for plugin operations +type Plugin struct { + client *redis.Client + config RedisPluginConfig + logger schemas.Logger +} + +const ( + PluginName string = "bifrost-redis" + PluginLoggerPrefix string = "[Bifrost Redis Plugin]" + RedisConnectionTimeout time.Duration = 5 * time.Second + RedisCacheSetTimeout time.Duration = 30 * time.Second +) + +// NewRedisPlugin creates a new Redis plugin instance with the provided configuration. +// It establishes a connection to Redis, tests connectivity, and returns a configured plugin. +// +// All Redis client options are passed directly to the Redis client, which handles its own defaults. +// The plugin only sets defaults for its own behavior (TTL, CacheOnlySuccessful, etc.). 
+// +// Parameters: +// - config: Redis and plugin configuration (only Addr is required) +// - logger: Logger instance for the plugin +// +// Returns: +// - schemas.Plugin: A configured Redis plugin instance +// - error: Any error that occurred during plugin initialization or Redis connection +func NewRedisPlugin(config RedisPluginConfig, logger schemas.Logger) (schemas.Plugin, error) { + // Validate required fields + if config.Addr == "" { + return nil, fmt.Errorf("redis address is required") + } + + if config.CacheKey == "" { + return nil, fmt.Errorf("cache key is required") + } + + // Set plugin-specific defaults (not Redis defaults) + if config.TTL == 0 { + logger.Debug(PluginLoggerPrefix + " TTL is not set, using default of 5 minutes") + config.TTL = 5 * time.Minute + } + if config.ContextTimeout == 0 { + config.ContextTimeout = 10 * time.Second // Only for our ping test + } + + // Set cache behavior defaults + if config.CacheByModel == nil { + config.CacheByModel = bifrost.Ptr(true) + } + if config.CacheByProvider == nil { + config.CacheByProvider = bifrost.Ptr(true) + } + + // Create Redis client with all provided options + opts := &redis.Options{ + Addr: config.Addr, + Username: config.Username, + Password: config.Password, + DB: config.DB, + PoolSize: config.PoolSize, + MinIdleConns: config.MinIdleConns, + MaxIdleConns: config.MaxIdleConns, + ConnMaxLifetime: config.ConnMaxLifetime, + ConnMaxIdleTime: config.ConnMaxIdleTime, + DialTimeout: config.DialTimeout, + ReadTimeout: config.ReadTimeout, + WriteTimeout: config.WriteTimeout, + } + + // Create Redis client + client := redis.NewClient(opts) + + // Test connection with configured timeout + ctx, cancel := context.WithTimeout(context.Background(), RedisConnectionTimeout) + defer cancel() + + _, err := client.Ping(ctx).Result() + if err != nil { + client.Close() + return nil, fmt.Errorf("failed to ping Redis at %s: %w", config.Addr, err) + } + + logger.Info(fmt.Sprintf("%s Successfully connected to Redis at 
%s", PluginLoggerPrefix, config.Addr)) + + return &Plugin{ + client: client, + config: config, + logger: logger, + }, nil +} + +// generateRequestHash creates an xxhash of the request for caching. +// It normalizes the request by including only the relevant fields based on configuration: +// - Provider (if CacheByProvider is true) +// - Model (if CacheByModel is true) +// - Input (chat completion or text completion) +// - Parameters (all parameters are included) +// +// Note: Fallbacks are excluded as they only affect error handling, not the actual response. +// +// Parameters: +// - req: The Bifrost request to hash +// +// Returns: +// - string: Hexadecimal representation of the xxhash +// - error: Any error that occurred during request normalization or hashing +func (plugin *Plugin) generateRequestHash(req *schemas.BifrostRequest, cacheKey string) (string, error) { + // Create a normalized request for hashing + // Note: Fallbacks are excluded as they only affect error handling, not the actual response + normalizedReq := struct { + Provider schemas.ModelProvider `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Input schemas.RequestInput `json:"input"` + Params *schemas.ModelParameters `json:"params,omitempty"` + }{ + Input: req.Input, + } + + // Include provider and model based on configuration + if plugin.config.CacheByProvider != nil && *plugin.config.CacheByProvider { + normalizedReq.Provider = req.Provider + } + if plugin.config.CacheByModel != nil && *plugin.config.CacheByModel { + normalizedReq.Model = req.Model + } + + // Include all parameters in cache key + normalizedReq.Params = req.Params + + // Marshal to JSON for consistent hashing + jsonData, err := json.Marshal(normalizedReq) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + // Generate hash based on configured algorithm + hash := xxhash.Sum64(jsonData) + return fmt.Sprintf("%s_%x", cacheKey, hash), nil +} + +// ContextKey is a custom type 
for context keys to prevent key collisions +type ContextKey string + +const ( + requestHashKey ContextKey = "redis_request_hash" + isCacheHitKey ContextKey = "redis_is_cache_hit" +) + +// GetName returns the canonical name of the Redis plugin. +// This name is used for plugin identification and logging purposes. +// +// Returns: +// - string: The plugin name "bifrost-redis" +func (p *Plugin) GetName() string { + return PluginName +} + +// PreHook is called before a request is processed by Bifrost. +// It checks if a cached response exists for the request hash and returns it if found. +// +// Parameters: +// - ctx: Pointer to the context.Context +// - req: The incoming Bifrost request +// +// Returns: +// - *schemas.BifrostRequest: The original request +// - *schemas.BifrostResponse: Cached response if found, nil otherwise +// - error: Any error that occurred during cache lookup +func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + // Get the cache key from the context + var cacheKey string + var ok bool + if ctx != nil { + cacheKey, ok = (*ctx).Value(ContextKey(plugin.config.CacheKey)).(string) + if !ok || cacheKey == "" { + return req, nil, nil + } + } else { + return req, nil, nil + } + + // Generate hash for the request + hash, err := plugin.generateRequestHash(req, cacheKey) + if err != nil { + // If we can't generate hash, just continue without caching + plugin.logger.Debug(PluginLoggerPrefix + " Failed to generate request hash, continuing without caching") + return req, nil, nil + } + + // Store hash in context for PostHook + *ctx = context.WithValue(*ctx, requestHashKey, hash) + + requestTypeValue := (*ctx).Value(bifrost.BifrostContextKeyRequestType) + if requestTypeValue == nil { + plugin.logger.Debug(PluginLoggerPrefix + " No request type found in context, continuing without caching") + return req, nil, nil + } + requestType, ok := 
requestTypeValue.(bifrost.RequestType) + if !ok { + plugin.logger.Debug(PluginLoggerPrefix + " Request type is not a bifrost.RequestType, continuing without caching") + return req, nil, nil + } + + // Create cache key + cacheKey = plugin.config.Prefix + hash + + if plugin.isStreamingRequest(requestType) { + // For streaming requests, find all chunks and create a stream + chunkPattern := cacheKey + "_chunk_*" + + // Get all chunk keys matching the pattern using SCAN + var chunkKeys []string + var cursor uint64 + for { + batch, c, err := plugin.client.Scan(*ctx, cursor, chunkPattern, 1000).Result() + if err != nil { + plugin.logger.Warn(PluginLoggerPrefix + " Failed to scan cached chunks, continuing with request") + return req, nil, nil + } + chunkKeys = append(chunkKeys, batch...) + cursor = c + if cursor == 0 { + break + } + } + + if len(chunkKeys) == 0 { + plugin.logger.Debug(PluginLoggerPrefix + " No cached chunks found, continuing with request") + return req, nil, nil + } + + plugin.logger.Info(fmt.Sprintf("%s Found %d cached chunks for request %s, returning stream", PluginLoggerPrefix, len(chunkKeys), cacheKey)) + + // Create stream channel + streamChan := make(chan *schemas.BifrostStream) + + go func() { + defer close(streamChan) + + // Get all chunk data + chunkData, err := plugin.client.MGet(*ctx, chunkKeys...).Result() + if err != nil { + plugin.logger.Warn(PluginLoggerPrefix + " Failed to retrieve cached chunks") + return + } + + var chunks []schemas.BifrostResponse + for _, data := range chunkData { + if data == nil { + continue + } + + // Unmarshal cached response + var cachedResponse schemas.BifrostResponse + if err := json.Unmarshal([]byte(data.(string)), &cachedResponse); err != nil { + plugin.logger.Warn(PluginLoggerPrefix + " Failed to unmarshal cached chunk, skipping") + continue + } + + chunks = append(chunks, cachedResponse) + } + + // Sort chunks by index + sort.Slice(chunks, func(i, j int) bool { + return chunks[i].ExtraFields.ChunkIndex < 
chunks[j].ExtraFields.ChunkIndex + }) + + // Send chunks in order + for _, chunk := range chunks { + if chunk.ExtraFields.RawResponse == nil { + chunk.ExtraFields.RawResponse = make(map[string]interface{}) + } + if rawResponseMap, ok := chunk.ExtraFields.RawResponse.(map[string]interface{}); ok { + rawResponseMap["bifrost_cached"] = true + rawResponseMap["bifrost_cache_key"] = fmt.Sprintf("%s_chunk_%d", cacheKey, chunk.ExtraFields.ChunkIndex) + } + + chunk.ExtraFields.Provider = req.Provider + + streamChan <- &schemas.BifrostStream{ + BifrostResponse: &chunk, + } + } + }() + + *ctx = context.WithValue(*ctx, isCacheHitKey, true) + + // Return short-circuit with stream + return req, &schemas.PluginShortCircuit{ + Stream: streamChan, + }, nil + + } else { + // Check if cached response exists + cachedData, err := plugin.client.Get(*ctx, cacheKey).Result() + if err != nil { + if err == redis.Nil { + plugin.logger.Debug(PluginLoggerPrefix + " No cached response found, continuing with request") + // No cached response found, continue with normal processing + return req, nil, nil + } + // Log error but continue processing + plugin.logger.Warn(PluginLoggerPrefix + " Failed to get cached response, continuing without caching") + return req, nil, nil + } + + // Unmarshal cached response + var cachedResponse schemas.BifrostResponse + if err := json.Unmarshal([]byte(cachedData), &cachedResponse); err != nil { + // If we can't unmarshal, just continue without cached response + plugin.logger.Warn(PluginLoggerPrefix + " Failed to unmarshal cached response, continuing without caching") + return req, nil, nil + } + + plugin.logger.Debug(fmt.Sprintf("%s Found cached response for request %s, returning it", PluginLoggerPrefix, cacheKey)) + + // Mark response as cached in extra fields + if cachedResponse.ExtraFields.RawResponse == nil { + cachedResponse.ExtraFields.RawResponse = make(map[string]interface{}) + } + if rawResponseMap, ok := 
cachedResponse.ExtraFields.RawResponse.(map[string]interface{}); ok { + rawResponseMap["bifrost_cached"] = true + rawResponseMap["bifrost_cache_key"] = cacheKey + } + cachedResponse.ExtraFields.Provider = req.Provider + + *ctx = context.WithValue(*ctx, isCacheHitKey, true) + + // Return cached response + return req, &schemas.PluginShortCircuit{ + Response: &cachedResponse, + }, nil + } + +} + +// PostHook is called after a response is received from a provider. +// It caches the response using the request hash as the key, with optional filtering +// based on the CacheOnlySuccessful configuration. +// +// The function performs the following operations: +// 1. Checks if CacheOnlySuccessful is enabled and skips caching for unsuccessful responses +// 2. Retrieves the request hash from the context (set during PreHook) +// 3. Marshals the response for storage +// 4. Stores the response in Redis asynchronously (non-blocking) +// +// The Redis SET operation runs in a separate goroutine to avoid blocking the response. +// The function gracefully handles errors and continues without caching if any step fails, +// ensuring that response processing is never interrupted by caching issues. 
+// +// Parameters: +// - ctx: Pointer to the context.Context containing the request hash +// - res: The response from the provider to be cached +// - bifrostErr: The error from the provider, if any (used for success determination) +// +// Returns: +// - *schemas.BifrostResponse: The original response, unmodified +// - *schemas.BifrostError: The original error, unmodified +// - error: Any error that occurred during caching preparation (always nil as errors are handled gracefully) +func (plugin *Plugin) PostHook(ctx *context.Context, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if bifrostErr != nil { + return res, bifrostErr, nil + } + + isCacheHit := (*ctx).Value(isCacheHitKey) + if isCacheHit != nil { + isCacheHitValue, ok := isCacheHit.(bool) + if ok && isCacheHitValue { + // If the cache hit is true, we should not cache + return res, nil, nil + } + } + + // Get the request type from context + requestTypeValue := (*ctx).Value(bifrost.BifrostContextKeyRequestType) + if requestTypeValue == nil { + plugin.logger.Debug(PluginLoggerPrefix + " No request type found in context, continuing without caching") + return res, nil, nil + } + + requestType, ok := requestTypeValue.(bifrost.RequestType) + if !ok { + plugin.logger.Debug(PluginLoggerPrefix + " Request type is not a bifrost.RequestType, continuing without caching") + return res, nil, nil + } + + // Get the hash from context + hashValue := (*ctx).Value(requestHashKey) + if hashValue == nil { + // If we don't have the hash, we can't cache (expected when cache key is not present) + return res, nil, nil + } + + hash, ok := hashValue.(string) + if !ok { + plugin.logger.Debug(PluginLoggerPrefix + " Hash is not a string, continuing without caching") + return res, nil, nil + } + + cacheTTL := plugin.config.TTL + + // Get the request TTL from the context + ttlValue := (*ctx).Value(ContextKey(plugin.config.CacheTTLKey)) + if ttlValue != nil { + 
ttl, ok := ttlValue.(time.Duration) + if !ok { + plugin.logger.Debug(PluginLoggerPrefix + " TTL is not a time.Duration, using default TTL") + } else { + cacheTTL = ttl + } + } + + // Create cache key + cacheKey := plugin.config.Prefix + hash + + // Add "chunk_{index}" to the cache key for streaming responses + if plugin.isStreamingRequest(requestType) { + cacheKey = fmt.Sprintf("%s_chunk_%d", cacheKey, res.ExtraFields.ChunkIndex) + } + + // Cache the response asynchronously to avoid blocking the response + go func() { + // Create a background context with timeout for the cache operation + // This ensures the cache operation doesn't run indefinitely + cacheCtx, cancel := context.WithTimeout(context.Background(), RedisCacheSetTimeout) + defer cancel() + + // Marshal response for caching + responseData, err := json.Marshal(res) + if err != nil { + // If we can't marshal, just return the response without caching + plugin.logger.Warn(PluginLoggerPrefix + " Failed to marshal response, continuing without caching") + return + } + + // Perform the Redis SET operation + err = plugin.client.Set(cacheCtx, cacheKey, responseData, cacheTTL).Err() + if err != nil { + plugin.logger.Warn(PluginLoggerPrefix + " Failed to cache response asynchronously: " + err.Error()) + } else { + plugin.logger.Debug(fmt.Sprintf("%s Cached response for request %s", PluginLoggerPrefix, cacheKey)) + } + }() + + return res, nil, nil +} + +// Cleanup performs cleanup operations for the Redis plugin. +// It removes all cached entries with the configured prefix and closes the Redis connection. +// +// The function performs the following operations: +// 1. Retrieves all cache keys matching the configured prefix pattern +// 2. Deletes all matching cache entries from Redis +// 3. Closes the Redis client connection +// +// This method should be called when shutting down the application to ensure +// proper resource cleanup and prevent connection leaks. 
+// +// Returns: +// - error: Any error that occurred during cleanup operations +func (plugin *Plugin) Cleanup() error { + // Get all keys matching the prefix using SCAN + var keys []string + var cursor uint64 + pattern := plugin.config.Prefix + "*" + for { + batch, c, err := plugin.client.Scan(context.Background(), cursor, pattern, 1000).Result() + if err != nil { + return fmt.Errorf("failed to scan keys for cleanup: %w", err) + } + keys = append(keys, batch...) + cursor = c + if cursor == 0 { + break + } + } + + if len(keys) > 0 { + if err := plugin.client.Del(context.Background(), keys...).Err(); err != nil { + return fmt.Errorf("failed to delete cache keys: %w", err) + } + plugin.logger.Debug(fmt.Sprintf("%s Cleaned up %d cache entries", PluginLoggerPrefix, len(keys))) + } + + if err := plugin.client.Close(); err != nil { + return fmt.Errorf("failed to close Redis client: %w", err) + } + + plugin.logger.Debug(PluginLoggerPrefix + " Successfully closed Redis connection") + return nil +} + +// ClearCacheForKey deletes a specific cache key from Redis. +// It is used to clear a specific cache key when needed. +// +// Parameters: +// - key: The cache key to delete +// +// Returns: +// - error: Any error that occurred during cache key deletion +func (plugin *Plugin) ClearCacheForKey(key string) error { + var keys []string + keys = append(keys, key) + + // For streaming requests, we need to delete all chunks for the key + chunkPattern := key + "_chunk_*" + + // Get all chunk keys matching the pattern using SCAN + var chunkKeys []string + var cursor uint64 + for { + batch, c, err := plugin.client.Scan(context.Background(), cursor, chunkPattern, 1000).Result() + if err != nil { + plugin.logger.Warn(PluginLoggerPrefix + " Failed to scan cached chunks, continuing with request") + return err + } + chunkKeys = append(chunkKeys, batch...) + cursor = c + if cursor == 0 { + break + } + } + + keys = append(keys, chunkKeys...) 
+ + if err := plugin.client.Del(context.Background(), keys...).Err(); err != nil { + plugin.logger.Warn(PluginLoggerPrefix + " Failed to get cached chunks, continuing with request") + return err + } + + return nil +} + +// UTILS FUNCTIONS + +func (plugin *Plugin) isStreamingRequest(requestType bifrost.RequestType) bool { + return requestType == bifrost.ChatCompletionStreamRequest || + requestType == bifrost.SpeechStreamRequest || + requestType == bifrost.TranscriptionStreamRequest +} diff --git a/plugins/redis/plugin_test.go b/plugins/redis/plugin_test.go new file mode 100644 index 0000000000..4c54956df8 --- /dev/null +++ b/plugins/redis/plugin_test.go @@ -0,0 +1,435 @@ +package redis + +import ( + "context" + "os" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/redis/go-redis/v9" +) + +// BaseAccount implements the schemas.Account interface for testing purposes. +// It provides mock implementations of the required methods to test the Maxim plugin +// with a basic OpenAI configuration. +type BaseAccount struct{} + +// GetConfiguredProviders returns a list of supported providers for testing. +// Currently only supports OpenAI for simplicity in testing. You are free to add more providers as needed. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +const ( + TestCacheKey = "x-test-cache-key" + TestPrefix = "test_redis_plugin_" +) + +// GetKeysForProvider returns a mock API key configuration for testing. +// Uses the OPENAI_API_KEY environment variable for authentication. 
+func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, + Weight: 1.0, + }, + }, nil +} + +// GetConfigForProvider returns default provider configuration for testing. +// Uses standard network and concurrency settings. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// clearTestKeysWithPrefix removes all Redis keys matching the test prefix using SCAN. +// This is safer than FLUSHALL as it only affects test keys, not the entire Redis instance. +func clearTestKeysWithPrefix(t *testing.T, client *redis.Client, prefix string) { + ctx := context.Background() + pattern := prefix + "*" + + var keys []string + var cursor uint64 + + // Use SCAN to find all keys matching the prefix + for { + batch, c, err := client.Scan(ctx, cursor, pattern, 1000).Result() + if err != nil { + t.Logf("Warning: Failed to scan keys with prefix %s: %v", prefix, err) + return + } + keys = append(keys, batch...) 
+ cursor = c + if cursor == 0 { + break + } + } + + // Delete keys in batches if any were found + if len(keys) > 0 { + if err := client.Del(ctx, keys...).Err(); err != nil { + t.Logf("Warning: Failed to delete test keys: %v", err) + } else { + t.Logf("Cleaned up %d test keys with prefix %s", len(keys), prefix) + } + } +} + +func TestRedisPlugin(t *testing.T) { + // Configure plugin with minimal Redis connection settings (only Addr is required) + config := RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: TestCacheKey, + Prefix: TestPrefix, // Use test-specific prefix to isolate test data + // Optional: add password if your Redis instance requires it + Password: os.Getenv("REDIS_PASSWORD"), + } + + logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) + + // Initialize the Redis plugin (it will create its own client) + plugin, err := NewRedisPlugin(config, logger) + if err != nil { + t.Skipf("Redis not available or failed to connect: %v", err) + return + } + + // Get the internal client for test setup (we need to type assert to access it) + pluginImpl := plugin.(*Plugin) + redisClient := pluginImpl.client + + // Clear test keys before test (safer than FLUSHALL) + clearTestKeysWithPrefix(t, redisClient, TestPrefix) + ctx := context.Background() + + account := BaseAccount{} + + ctx = context.WithValue(ctx, ContextKey(TestCacheKey), "test-value") + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: logger, + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + // Create a test request + testRequest := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What is Bifrost? 
Answer in one short sentence."), + }, + }, + }, + }, + Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.7), + MaxTokens: bifrost.Ptr(50), + }, + } + + t.Log("Making first request (should go to OpenAI and be cached)...") + + // Make first request (will go to OpenAI and be cached) + start1 := time.Now() + response1, bifrostErr1 := client.ChatCompletionRequest(ctx, testRequest) + duration1 := time.Since(start1) + + if bifrostErr1 != nil { + t.Fatalf("First request failed: %v", bifrostErr1) + } + + if response1 == nil { + t.Fatal("First response is nil") + } + + if len(response1.Choices) == 0 { + t.Fatal("First response has no choices") + } + + if response1.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("First response content is nil") + } + + t.Logf("First request completed in %v", duration1) + t.Logf("Response: %s", *response1.Choices[0].Message.Content.ContentStr) + + // Wait a moment to ensure cache is written + time.Sleep(100 * time.Millisecond) + + t.Log("Making second identical request (should be served from cache)...") + + // Make second identical request (should be cached) + // Use the same context with cache key for the second request + start2 := time.Now() + response2, bifrostErr2 := client.ChatCompletionRequest(ctx, testRequest) + duration2 := time.Since(start2) + + if bifrostErr2 != nil { + t.Fatalf("Second request failed: %v", bifrostErr2) + } + + if response2 == nil { + t.Fatal("Second response is nil") + } + + if len(response2.Choices) == 0 { + t.Fatal("Second response has no choices") + } + + if response2.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Second response content is nil") + } + + t.Logf("Second request completed in %v", duration2) + t.Logf("Response: %s", *response2.Choices[0].Message.Content.ContentStr) + + // Check if second request was cached + cached := false + var cacheKeyValue interface{} + + if response2.ExtraFields.RawResponse == nil { + t.Error("Second response ExtraFields.RawResponse is nil - 
expected cache metadata") + } else { + rawMap, ok := response2.ExtraFields.RawResponse.(map[string]interface{}) + if !ok { + t.Error("Second response ExtraFields.RawResponse is not a map - expected cache metadata") + } else { + cachedFlag, exists := rawMap["bifrost_cached"] + if !exists { + t.Error("Second response missing 'bifrost_cached' flag - expected cache hit") + } else { + cachedBool, ok := cachedFlag.(bool) + if !ok { + t.Error("'bifrost_cached' flag is not a boolean") + } else if !cachedBool { + t.Error("'bifrost_cached' flag is false - expected cache hit") + } else { + cached = true + t.Log("Second request was served from Redis cache!") + + cacheKeyValue, exists = rawMap["bifrost_cache_key"] + if !exists { + t.Error("Cache metadata missing 'bifrost_cache_key'") + } else { + t.Logf("Cache key: %v", cacheKeyValue) + } + } + } + } + } + + // Performance comparison + t.Logf("Performance Summary:") + t.Logf("First request (OpenAI): %v", duration1) + t.Logf("Second request (Cache): %v", duration2) + + if !cached { + t.Fatal("Second request was not cached - cache functionality is not working") + } + + if duration2 >= duration1 { + t.Errorf("Cache request took longer than original request: cache=%v, original=%v", duration2, duration1) + } else { + speedup := float64(duration1) / float64(duration2) + t.Logf("Cache speedup: %.2fx faster", speedup) + + // Assert that cache is at least 2x faster (reasonable expectation) + if speedup < 2.0 { + t.Errorf("Cache speedup is less than 2x: got %.2fx", speedup) + } + } + + // Verify responses are identical (content should be the same) + content1 := *response1.Choices[0].Message.Content.ContentStr + content2 := *response2.Choices[0].Message.Content.ContentStr + + if content1 != content2 { + t.Errorf("Response content differs between cached and original:\nOriginal: %s\nCached: %s", content1, content2) + } else { + t.Log("Both responses have identical content") + } + + // Verify provider information is maintained in cached 
response + // The cached response should have the provider set, while the original might not + if response2.ExtraFields.Provider != testRequest.Provider { + t.Errorf("Provider mismatch in cached response: expected %s, got %s", + testRequest.Provider, response2.ExtraFields.Provider) + } + + t.Log("Redis caching test completed successfully!") + t.Log("The Redis plugin successfully cached the response and served it faster on the second request.") +} + +func TestRedisPluginStreaming(t *testing.T) { + // Configure plugin with minimal Redis connection settings + config := RedisPluginConfig{ + Addr: "localhost:6379", + CacheKey: TestCacheKey, + Prefix: TestPrefix, // Use test-specific prefix to isolate test data + Password: os.Getenv("REDIS_PASSWORD"), + } + + logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) + + // Initialize the Redis plugin + plugin, err := NewRedisPlugin(config, logger) + if err != nil { + t.Skipf("Redis not available or failed to connect: %v", err) + return + } + + // Get the internal client for test setup + pluginImpl := plugin.(*Plugin) + redisClient := pluginImpl.client + + // Clear test keys before test (safer than FLUSHALL) + clearTestKeysWithPrefix(t, redisClient, TestPrefix) + ctx := context.Background() + + account := BaseAccount{} + ctx = context.WithValue(ctx, ContextKey(TestCacheKey), "test-stream-value") + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: logger, + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + // Create a test streaming request + testRequest := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Count from 1 to 3, each number on a new line."), + }, + }, + }, + }, 
+ Params: &schemas.ModelParameters{ + Temperature: bifrost.Ptr(0.0), // Use 0 temperature for more predictable responses + MaxTokens: bifrost.Ptr(20), + }, + } + + t.Log("Making first streaming request (should go to OpenAI and be cached)...") + + // Make first streaming request + start1 := time.Now() + stream1, bifrostErr1 := client.ChatCompletionStreamRequest(ctx, testRequest) + if bifrostErr1 != nil { + t.Fatalf("First streaming request failed: %v", bifrostErr1) + } + + var responses1 []schemas.BifrostResponse + for streamMsg := range stream1 { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in first stream: %v", streamMsg.BifrostError) + } + if streamMsg.BifrostResponse != nil { + responses1 = append(responses1, *streamMsg.BifrostResponse) + } + } + duration1 := time.Since(start1) + + if len(responses1) == 0 { + t.Fatal("First streaming request returned no responses") + } + + t.Logf("First streaming request completed in %v with %d chunks", duration1, len(responses1)) + + // Wait for cache to be written + time.Sleep(200 * time.Millisecond) + + t.Log("Making second identical streaming request (should be served from cache)...") + + // Make second identical streaming request + start2 := time.Now() + stream2, bifrostErr2 := client.ChatCompletionStreamRequest(ctx, testRequest) + if bifrostErr2 != nil { + t.Fatalf("Second streaming request failed: %v", bifrostErr2) + } + + var responses2 []schemas.BifrostResponse + for streamMsg := range stream2 { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in second stream: %v", streamMsg.BifrostError) + } + if streamMsg.BifrostResponse != nil { + responses2 = append(responses2, *streamMsg.BifrostResponse) + } + } + duration2 := time.Since(start2) + + if len(responses2) == 0 { + t.Fatal("Second streaming request returned no responses") + } + + t.Logf("Second streaming request completed in %v with %d chunks", duration2, len(responses2)) + + // Validate that both streams have the same number of chunks + if 
len(responses1) != len(responses2) { + t.Errorf("Stream chunk count mismatch: original=%d, cached=%d", len(responses1), len(responses2)) + } + + // Validate that the second stream was cached + cached := false + for _, response := range responses2 { + if response.ExtraFields.RawResponse != nil { + if rawMap, ok := response.ExtraFields.RawResponse.(map[string]interface{}); ok { + if cachedFlag, exists := rawMap["bifrost_cached"]; exists { + if cachedBool, ok := cachedFlag.(bool); ok && cachedBool { + cached = true + break + } + } + } + } + } + + if !cached { + t.Fatal("Second streaming request was not served from cache") + } + + // Validate performance improvement + if duration2 >= duration1 { + t.Errorf("Cached stream took longer than original: cache=%v, original=%v", duration2, duration1) + } else { + speedup := float64(duration1) / float64(duration2) + t.Logf("Streaming cache speedup: %.2fx faster", speedup) + } + + // Validate chunk ordering is maintained + for i := range responses2 { + if responses2[i].ExtraFields.ChunkIndex != responses1[i].ExtraFields.ChunkIndex { + t.Errorf("Chunk index mismatch at position %d: original=%d, cached=%d", + i, responses1[i].ExtraFields.ChunkIndex, responses2[i].ExtraFields.ChunkIndex) + } + } + + t.Log("Redis streaming cache test completed successfully!") +}