diff --git a/docs/media/plugins/circuit-breaker-states.png b/docs/media/plugins/circuit-breaker-states.png new file mode 100644 index 0000000000..b5879d7bef Binary files /dev/null and b/docs/media/plugins/circuit-breaker-states.png differ diff --git a/plugins/circuitbreaker/README.md b/plugins/circuitbreaker/README.md new file mode 100644 index 0000000000..c8fed3d0d5 --- /dev/null +++ b/plugins/circuitbreaker/README.md @@ -0,0 +1,270 @@ +# Bifrost Circuit Breaker Plugin + +The Circuit Breaker plugin for Bifrost provides automatic failure detection and recovery for AI provider requests. It monitors request failures and slow calls, automatically opening the circuit when thresholds are exceeded to prevent cascading failures. + +## Quick Start + +### Download the Plugin + + ```bash + go get github.com/maximhq/bifrost/plugins/circuitbreaker + ``` + +### Basic Usage + +```go +package main + +import ( + "context" + "time" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + circuitbreaker "github.com/maximhq/bifrost/plugins/circuitbreaker" +) + +func main() { + // Create plugin with default configuration + circuitbreakerPlugin, err := circuitbreaker.NewCircuitBreakerPlugin(circuitbreaker.CircuitBreakerConfig{ + FailureRateThreshold: 0.5, // 50% failure rate threshold + SlowCallRateThreshold: 0.5, // 50% slow call rate threshold + SlowCallDurationThreshold: 5 * time.Second, + MinimumNumberOfCalls: 10, + SlidingWindowType: circuitbreaker.CountBased, // Track last N calls + SlidingWindowSize: 100, // Track last 100 calls + PermittedNumberOfCallsInHalfOpenState: 5, + MaxWaitDurationInHalfOpenState: 60 * time.Second, + }) + if err != nil { + panic(err) + } + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{circuitbreakerPlugin}, + }) + if err != nil { + panic(err) + } + defer client.Cleanup() + + // Circuit breaker will automatically protect your requests + response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello!"), + }, + }, + }, + }, + }) +} +``` + +### State Diagram of Circuit Breaker +![Circuit Breaker States](../../docs/media/plugins/circuit-breaker-states.png) + +## Configuration + +### CircuitBreakerConfig + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `FailureRateThreshold` | `float64` | `0.5` | Failure rate threshold (0.0 to 1.0) | +| `SlowCallRateThreshold` | `float64` | `0.5` | Slow call rate threshold (0.0 to 1.0) | +| `SlowCallDurationThreshold` | `time.Duration` | `5s` | Duration threshold for slow calls | +| `MinimumNumberOfCalls` | `int` | `10` | Minimum calls before evaluation | +| `SlidingWindowType` | `string` | `"count-based"` | `"count-based"` or `"time-based"` | +| `SlidingWindowSize` | `int` | `100` | Size of sliding window (calls for count-based, seconds for time-based) | +| `PermittedNumberOfCallsInHalfOpenState` | `int` | `5` | Calls allowed in half-open state | +| `MaxWaitDurationInHalfOpenState` | `time.Duration` | `60s` | Wait time before half-open transition | +| `Logger` | `schemas.Logger` | `bifrost.NewDefaultLogger(schemas.LogLevelInfo)` | Logger for circuit breaker operations | + +### Sliding Window Types + +The circuit breaker supports two types of sliding windows for collecting metrics: + +#### Count-Based Sliding Window +- **Type**: `"count-based"` +- **Size**: Number of most recent calls to track +- **Behavior**: Maintains a fixed-size circular buffer of the last N calls +- **Use Case**: When you want to evaluate based on a specific number of recent requests +- **Example**: Track the last 100 calls to evaluate failure rates + +#### Time-Based Sliding Window +- **Type**: `"time-based"` +- **Size**: Duration in seconds to look back +- **Behavior**: Maintains all calls within the specified time window +- **Use Case**: When you want to evaluate based on a time period +- **Example**: Track all calls in the last 5 minutes to evaluate failure rates + +### Configuration Examples + +#### Count-Based Sliding Window (Default) + +```go +config := circuitbreaker.CircuitBreakerConfig{ + FailureRateThreshold: 0.3, // 30% failure rate threshold + SlowCallRateThreshold: 0.4, // 40% slow call rate threshold + SlowCallDurationThreshold: 10 * time.Second, + MinimumNumberOfCalls: 20, + SlidingWindowType: circuitbreaker.CountBased, // Track last N calls + SlidingWindowSize: 200, // Track last 200 calls + PermittedNumberOfCallsInHalfOpenState: 3, + MaxWaitDurationInHalfOpenState: 30 * time.Second, +} +``` + +#### Time-Based Sliding Window + +```go +config := circuitbreaker.CircuitBreakerConfig{ + FailureRateThreshold: 0.3, // 30% failure rate threshold + SlowCallRateThreshold: 0.4, // 40% slow call rate threshold + SlowCallDurationThreshold: 10 * time.Second, + MinimumNumberOfCalls: 20, + SlidingWindowType: circuitbreaker.TimeBased, // Track calls in time window + SlidingWindowSize: 300, // Track calls in last 300 seconds (5 minutes) + PermittedNumberOfCallsInHalfOpenState: 3, + MaxWaitDurationInHalfOpenState: 30 * time.Second, +} +``` + +### Logging Configuration + +The circuit breaker plugin includes comprehensive logging to help you monitor its behavior. By default, it uses Bifrost's default logger with `Info` level logging. You can customize the logger by providing your own implementation: + +```go +// Use custom logger +customLogger := yourCustomLoggerImplementation +config := circuitbreaker.CircuitBreakerConfig{ + FailureRateThreshold: 0.3, + // ... other config options + Logger: customLogger, // Use your custom logger +} + +// Or use Bifrost's default logger with different log level +config := circuitbreaker.CircuitBreakerConfig{ + FailureRateThreshold: 0.3, + // ... other config options + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), // More verbose logging +} +``` + +## Circuit States + +### CLOSED (Normal Operation) +- Requests are sent to providers normally +- Circuit breaker monitors failures and slow calls +- Metrics are collected in sliding window + +### OPEN (Failure Protection) +- All requests are immediately rejected +- No provider calls are made +- Prevents cascading failures +- Automatically transitions to HALF_OPEN after wait duration + +### HALF_OPEN (Recovery Testing) +- Limited number of requests are allowed through +- Success/failure determines next state +- Success → CLOSED (recovery complete) +- Failure → OPEN (still failing) + +### CircuitState Type + +The `CircuitState` type is an `int32` enum that represents the three possible states of the circuit breaker. It includes a `String()` method that provides human-readable string representations: + +- `StateClosed` → `"CLOSED"` +- `StateOpen` → `"OPEN"` +- `StateHalfOpen` → `"HALF_OPEN"` + +This is useful for logging, debugging, and displaying circuit breaker status in monitoring dashboards. + +### Error Classification + +The circuit breaker distinguishes between different types of errors: +- **Server Errors (5xx)**: Considered failures that contribute to the failure rate +- **Rate Limit Errors (429)**: Considered failures that contribute to the failure rate +- **Other Client Errors (4xx)**: Considered successful for circuit breaker purposes (e.g., invalid requests, authentication errors) + +This classification ensures that rate limiting issues and server-side problems trigger circuit breaker protection, while other client-side issues (like invalid API keys or malformed requests) don't. + +## Monitoring + +### Get Circuit State + +```go +state := plugin.GetState(schemas.OpenAI) +switch state { +case circuitbreaker.StateClosed: + fmt.Println("Circuit is CLOSED - normal operation") +case circuitbreaker.StateOpen: + fmt.Println("Circuit is OPEN - requests blocked") +case circuitbreaker.StateHalfOpen: + fmt.Println("Circuit is HALF_OPEN - testing recovery") +} +``` + +### Get Metrics + +```go +metrics, err := plugin.GetMetrics(schemas.OpenAI) +if err == nil { + fmt.Printf("Total Calls: %d\n", metrics.TotalCalls) + fmt.Printf("Failed Calls: %d\n", metrics.FailedCalls) + fmt.Printf("Failure Rate: %.2f%%\n", metrics.FailureRate*100) + fmt.Printf("Slow Call Rate: %.2f%%\n", metrics.SlowCallRate*100) +} +``` + +## Advanced Operations + +### Manual Circuit Control + +The circuit breaker provides manual control functions for testing and emergency situations: + +```go +// Force the circuit to open state (blocks all requests) +err := plugin.ForceOpen(schemas.OpenAI) +if err != nil { + fmt.Printf("Error forcing circuit open: %v\n", err) +} + +// Force the circuit to closed state (allows all requests) +err = plugin.ForceClose(schemas.OpenAI) +if err != nil { + fmt.Printf("Error forcing circuit closed: %v\n", err) +} + +// Reset the circuit breaker (clears all metrics and returns to closed state) +err = plugin.Reset(schemas.OpenAI) +if err != nil { + fmt.Printf("Error resetting circuit: %v\n", err) +} +``` + +**Note**: Manual control should be used sparingly and primarily for testing or emergency situations. The automatic circuit breaker logic is designed to handle most scenarios optimally. + +## Performance + +The Circuit Breaker plugin is optimized for high-performance scenarios: + +- **Atomic Operations**: Uses atomic counters for thread-safe statistics +- **Lock-Free Reads**: Read operations don't block other operations +- **Memory Efficient**: Pre-allocated data structures with minimal allocations + +## Best Practices + +1. **Monitor Metrics**: Regularly check circuit states and failure rates +2. **Adjust Thresholds**: Lower thresholds for critical services, higher for non-critical +3. **Test Recovery**: Verify half-open state works correctly in your environment +4. **Use Fallbacks**: Combine with Bifrost's fallback providers for maximum resilience + +**Need help?** Check the [Bifrost documentation](../../docs/plugins.md) or open an issue on GitHub. + diff --git a/plugins/circuitbreaker/go.mod b/plugins/circuitbreaker/go.mod new file mode 100644 index 0000000000..450c17b58f --- /dev/null +++ b/plugins/circuitbreaker/go.mod @@ -0,0 +1,36 @@ +module github.com/maximhq/bifrost/plugins/circuitbreaker + +go 1.24.1 + +toolchain go1.24.4 + +require github.com/maximhq/bifrost/core v1.1.6 + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.3 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mark3labs/mcp-go v0.32.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.60.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/text v0.24.0 // indirect +) diff --git a/plugins/circuitbreaker/go.sum b/plugins/circuitbreaker/go.sum new file mode 100644 index 0000000000..5d87c56717 --- /dev/null +++ b/plugins/circuitbreaker/go.sum @@ -0,0 +1,58 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/maximhq/bifrost/core v1.1.3 h1:EBxwxqpCNNs3ck44qBwqrTiKGD+1Avyb57fM3/2wTKs= +github.com/maximhq/bifrost/core v1.1.3/go.mod h1:8ycaWQ9bjQezoUT/x6a82VmPjoqLzyGglQ0RnnlZjqo= +github.com/maximhq/bifrost/core v1.1.6 h1:rZrfPVcAfNggfBaOTdu/w+xNwDhW79bfexXsw8LRoMQ= +github.com/maximhq/bifrost/core v1.1.6/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= +github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= diff --git a/plugins/circuitbreaker/main.go b/plugins/circuitbreaker/main.go new file mode 100644 index 0000000000..82129b549f --- /dev/null +++ b/plugins/circuitbreaker/main.go @@ -0,0 +1,519 @@ +// Package circuitbreaker provides a circuit breaker plugin for the Bifrost system. +// The circuit breaker monitors request failures and slow calls to automatically +// open the circuit when thresholds are exceeded, preventing cascading failures. +// +// Configuration: +// The plugin accepts a CircuitBreakerConfig and automatically applies sensible defaults +// for any invalid or missing configuration values. This makes the plugin robust and +// user-friendly, as it will work even with incomplete or invalid configurations. +// +// Default Configuration Values: +// - FailureRateThreshold: 0.5 (50% failure rate threshold) +// - SlowCallRateThreshold: 0.5 (50% slow call rate threshold) +// - SlowCallDurationThreshold: 5 seconds +// - MinimumNumberOfCalls: 10 (minimum calls before evaluation) +// - SlidingWindowType: "count-based" +// - SlidingWindowSize: 100 (number of calls in window) +// - PermittedNumberOfCallsInHalfOpenState: 5 +// - MaxWaitDurationInHalfOpenState: 60 seconds +// +// Usage: +// +// config := CircuitBreakerConfig{ +// FailureRateThreshold: 0.3, // Only valid values need to be specified +// // Other values will use defaults +// } +// plugin, err := NewCircuitBreakerPlugin(config) +// +// The plugin will log any default values that were applied during initialization. +package circuitbreaker + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +const PluginName = "bifrost-circuit-breaker" + +// CircuitState represents the current state of a circuit breaker. +type CircuitState int32 + +const ( + StateClosed CircuitState = iota + StateOpen + StateHalfOpen +) + +// String returns the string representation of the circuit state +func (s CircuitState) String() string { + switch s { + case StateClosed: + return "CLOSED" + case StateOpen: + return "OPEN" + case StateHalfOpen: + return "HALF_OPEN" + default: + return "UNKNOWN" + } +} + +// SlidingWindowType defines the type of sliding window used for metrics collection. +type SlidingWindowType string + +const ( + CountBased SlidingWindowType = "count-based" + TimeBased SlidingWindowType = "time-based" +) + +// CallResult tracks the result of a single API call for circuit breaker evaluation. +type CallResult struct { + Duration time.Duration // Duration of the call + Success bool // Whether the call was successful + Timestamp time.Time // When the call was made + IsSlowCall bool // Whether the call exceeded the slow call threshold +} + +// SlidingWindow defines the interface for collecting and analyzing call metrics +// over a sliding window of time or count. +type SlidingWindow interface { + RecordCall(result CallResult) + GetMetrics() WindowMetrics + Reset() +} + +// WindowMetrics contains aggregated metrics for the sliding window. +type WindowMetrics struct { + TotalCalls int // Total number of calls in the window + FailedCalls int // Number of failed calls in the window + SlowCalls int // Number of slow calls in the window + FailureRate float64 // Failure rate as a percentage (0.0 to 1.0) + SlowCallRate float64 // Slow call rate as a percentage (0.0 to 1.0) +} + +// CountBasedWindow implements a count-based sliding window that maintains +// a fixed number of most recent call results. +type CountBasedWindow struct { + mu sync.RWMutex + calls []CallResult + maxSize int + position int + full bool +} + +// TimeBasedWindow implements a time-based sliding window that maintains +// call results within a specified time duration. +type TimeBasedWindow struct { + mu sync.RWMutex + calls []CallResult + windowDuration time.Duration + lastCleanup time.Time // Last time cleanup was performed + cleanupThreshold int // Number of calls before triggering cleanup + maxCallsBeforeCleanup int // Maximum calls to accumulate before forcing cleanup +} + +// ProviderCircuitState maintains the circuit breaker state for a specific provider. +// It uses atomic operations for thread-safe state transitions and mutex-protected +// sliding window access. +type ProviderCircuitState struct { + // Atomic variables for thread-safe access + state int32 // Current circuit state + stateTransitionTime int64 // Unix nano timestamp of last state transition + halfOpenCallsPermitted int32 // Number of calls allowed in half-open state + halfOpenCallsAttempted int32 // Number of calls attempted in half-open state + inFlightCalls int32 // Number of calls currently in progress + halfOpenSuccesses int32 // Number of successful calls in half-open state + + // Protected by mutex + mu sync.RWMutex + slidingWindow SlidingWindow +} + +// CircuitBreakerConfig contains all configuration parameters for the circuit breaker. +type CircuitBreakerConfig struct { + FailureRateThreshold float64 // Failure rate threshold (0.0 to 1.0) + SlowCallRateThreshold float64 // Slow call rate threshold (0.0 to 1.0) + SlowCallDurationThreshold time.Duration // Duration threshold for slow calls + MinimumNumberOfCalls int // Minimum calls before evaluation + SlidingWindowType SlidingWindowType // Type of sliding window + SlidingWindowSize int // Size of sliding window + PermittedNumberOfCallsInHalfOpenState int // Calls allowed in half-open state + MaxWaitDurationInHalfOpenState time.Duration // Wait time before half-open transition + Logger schemas.Logger // Logger for circuit breaker, use default logger if not provided +} + +// CircuitBreaker implements the Bifrost plugin interface to provide circuit breaker +// functionality. It maintains separate circuit states for each provider and uses +// sliding windows to track call metrics. +type CircuitBreaker struct { + config CircuitBreakerConfig + + // Per-provider circuit states + mu sync.RWMutex + providers map[schemas.ModelProvider]*ProviderCircuitState +} + +// CircuitBreakerMetrics provides observability data for a circuit breaker instance. +type CircuitBreakerMetrics struct { + State CircuitState // Current circuit state + FailureRate float64 // Current failure rate + SlowCallRate float64 // Current slow call rate + TotalCalls int // Total calls in window + FailedCalls int // Failed calls in window + SlowCalls int // Slow calls in window + StateTransitionTime time.Time // Time of last state transition + InFlightCalls int // Currently in-flight calls + HalfOpenCallsAttempted int // Calls attempted in half-open state + HalfOpenCallsPermitted int // Calls permitted in half-open state + HalfOpenSuccesses int // Successful calls in half-open state +} + +// Context keys for storing circuit breaker data in request context. +type contextKey string + +const ( + callStartTimeKey contextKey = "circuitbreaker_call_start_time" + circuitStateKey contextKey = "circuitbreaker_circuit_state" + providerKey contextKey = "circuitbreaker_provider" +) + +// NewCircuitBreakerPlugin creates a new circuit breaker plugin with the given configuration. +// It validates the configuration and uses default values for any invalid parameters. +func NewCircuitBreakerPlugin(config CircuitBreakerConfig) (*CircuitBreaker, error) { + // Apply default values for invalid configurations + validatedConfig, appliedDefaults := ValidateConfigWithDefaults(config) + + // Log any defaults that were applied using the configured logger + if len(appliedDefaults) > 0 && validatedConfig.Logger != nil { + validatedConfig.Logger.Info(fmt.Sprintf("Circuit breaker plugin: Applied %d default values", len(appliedDefaults))) + for _, defaultMsg := range appliedDefaults { + validatedConfig.Logger.Info(fmt.Sprintf(" - %s", defaultMsg)) + } + } + + return &CircuitBreaker{ + config: validatedConfig, + providers: make(map[schemas.ModelProvider]*ProviderCircuitState), + }, nil +} + +// GetName returns the plugin name for identification. +func (p *CircuitBreaker) GetName() string { + return PluginName +} + +// getOrCreateProviderState retrieves or creates the circuit breaker state for a provider. +// It uses a double-checked locking pattern to ensure thread-safe state creation. +func (p *CircuitBreaker) getOrCreateProviderState(provider schemas.ModelProvider) *ProviderCircuitState { + // Fast path: Try to get existing state with read lock + p.mu.RLock() + if state, exists := p.providers[provider]; exists { + p.mu.RUnlock() + return state + } + p.mu.RUnlock() + + // Slow path: Provider doesn't exist, need to create it + p.mu.Lock() + defer p.mu.Unlock() + + // Double-check: Another goroutine might have created it while we waited for the lock + if state, exists := p.providers[provider]; exists { + return state + } + + // Create new circuit breaker state for this provider + newState := &ProviderCircuitState{ + state: int32(StateClosed), + stateTransitionTime: time.Now().UnixNano(), + halfOpenCallsPermitted: int32(p.config.PermittedNumberOfCallsInHalfOpenState), + halfOpenCallsAttempted: 0, + inFlightCalls: 0, + halfOpenSuccesses: 0, + slidingWindow: p.createSlidingWindow(), + } + + p.providers[provider] = newState + return newState +} + +// createSlidingWindow creates a new sliding window based on the configuration type. +func (p *CircuitBreaker) createSlidingWindow() SlidingWindow { + switch p.config.SlidingWindowType { + case CountBased: + return newCountBasedWindow(p.config.SlidingWindowSize) + case TimeBased: + return newTimeBasedWindow(time.Duration(p.config.SlidingWindowSize) * time.Second) + default: + return newCountBasedWindow(p.config.SlidingWindowSize) + } +} + +// NewCountBasedWindow creates a new count-based sliding window with the specified size. +func newCountBasedWindow(size int) *CountBasedWindow { + return &CountBasedWindow{ + calls: make([]CallResult, size), + maxSize: size, + position: 0, + full: false, + } +} + +// NewTimeBasedWindow creates a new time-based sliding window with the specified duration. +func newTimeBasedWindow(duration time.Duration) *TimeBasedWindow { + return &TimeBasedWindow{ + calls: make([]CallResult, 0), + windowDuration: duration, + lastCleanup: time.Now(), + cleanupThreshold: 10, // Trigger cleanup every 10 calls + maxCallsBeforeCleanup: 100, // Force cleanup after 100 calls + } +} + +// GetProviderState safely retrieves an existing provider state (read-only). +// Returns the state and a boolean indicating if the provider exists. +func (p *CircuitBreaker) GetProviderState(provider schemas.ModelProvider) (*ProviderCircuitState, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + + state, exists := p.providers[provider] + return state, exists +} + +// PreHook implements the Plugin interface and is called before each request. +// It checks the circuit breaker state and either allows the request to proceed +// or short-circuits with an error if the circuit is open or half-open limits are exceeded. +func (p *CircuitBreaker) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + if req == nil { + return nil, nil, fmt.Errorf("request cannot be nil") + } + + provider := req.Provider + circuitState := p.getOrCreateProviderState(provider) + + // Get current state atomically + currentState := CircuitState(atomic.LoadInt32(&circuitState.state)) + + // Handle based on current state + switch currentState { + case StateOpen: + // Check if wait duration has passed + if !p.shouldTransitionToHalfOpen(circuitState) { + // Short-circuit with error - allow fallbacks to other providers + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Error: schemas.ErrorField{ + Message: fmt.Sprintf("Service temporarily unavailable: %s circuit breaker is OPEN due to high failure rate (%0.2f%%). Circuit will attempt recovery in %s. Please retry later or use an alternative provider.", provider, p.config.FailureRateThreshold*100, p.config.MaxWaitDurationInHalfOpenState), + Type: bifrost.Ptr("circuitbreaker_open"), + }, + }, + }, nil + } + // Transition to half-open and continue + p.transitionToHalfOpen(circuitState) + fallthrough + + case StateHalfOpen: + // Check if we're within permitted call limit + if !p.canMakeHalfOpenCall(circuitState) { + // Short-circuit with error - allow fallbacks to other providers + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Error: schemas.ErrorField{ + Message: fmt.Sprintf("Service testing capacity: %s circuit breaker is in HALF_OPEN state with limited capacity (%d/%d calls). Please retry in a moment or use an alternative provider.", provider, atomic.LoadInt32(&circuitState.halfOpenCallsAttempted), atomic.LoadInt32(&circuitState.halfOpenCallsPermitted)), + Type: bifrost.Ptr("circuitbreaker_half_open_limit"), + }, + AllowFallbacks: nil, // Allow fallbacks by default + }, + }, nil + } + + case StateClosed: + // Allow request to proceed + } + + // Increment in-flight counter + atomic.AddInt32(&circuitState.inFlightCalls, 1) + + // Store call start time and circuit state in context + *ctx = context.WithValue(*ctx, callStartTimeKey, time.Now()) + *ctx = context.WithValue(*ctx, circuitStateKey, circuitState) + *ctx = context.WithValue(*ctx, providerKey, provider) + + return req, nil, nil +} + +// PostHook implements the Plugin interface and is called after each request. +// It records the call result, updates metrics, and evaluates state transitions +// based on the sliding window metrics. +func (p *CircuitBreaker) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + // Extract data from context + callStartTime := GetCallStartTime(*ctx) + circuitState := GetCircuitState(*ctx) + + if circuitState == nil { + // No circuit state found, return as-is + return result, err, nil + } + + // Calculate call duration + callDuration := time.Since(callStartTime) + + // Determine if this is a server error (status code 5xx) + isServerError := IsServerError(err) + // Determine if this is a rate limit exceeded error (status code 429) + isRateLimitExceeded := IsRateLimitExceeded(err) + + // Determine call result - count as failure for server errors (5xx) and rate limit exceeded (429) + // Client errors (4xx) and other errors are considered successful for circuit breaker purposes + callResult := CallResult{ + Duration: callDuration, + Success: (err == nil && result != nil) || (!isServerError && !isRateLimitExceeded), + Timestamp: callStartTime, + IsSlowCall: callDuration > p.config.SlowCallDurationThreshold, + } + + // Record call result in sliding window + circuitState.mu.Lock() + circuitState.slidingWindow.RecordCall(callResult) + metrics := circuitState.slidingWindow.GetMetrics() + circuitState.mu.Unlock() + + // Evaluate state transition based on current state + currentState := CircuitState(atomic.LoadInt32(&circuitState.state)) + newState := p.evaluateStateTransition(circuitState, currentState, metrics, callResult) + + // Perform state transition if needed + if newState != currentState { + p.transitionState(circuitState, newState) + } + + // Decrement in-flight counter + defer atomic.AddInt32(&circuitState.inFlightCalls, -1) + + return result, err, nil +} + +// shouldTransitionToHalfOpen checks if enough time has passed to transition from open to half-open. +func (p *CircuitBreaker) shouldTransitionToHalfOpen(state *ProviderCircuitState) bool { + transitionTime := atomic.LoadInt64(&state.stateTransitionTime) + waitDuration := p.config.MaxWaitDurationInHalfOpenState + return time.Since(time.Unix(0, transitionTime)) >= waitDuration +} + +// canMakeHalfOpenCall checks if we can make a call in half-open state. +func (p *CircuitBreaker) canMakeHalfOpenCall(state *ProviderCircuitState) bool { + permitted := atomic.LoadInt32(&state.halfOpenCallsPermitted) + // Atomically increment attempted counter and check if within permitted + attempted := atomic.AddInt32(&state.halfOpenCallsAttempted, 1) + return attempted <= permitted +} + +// evaluateStateTransition determines if a state transition should occur based on +// current metrics and the last call result. +func (p *CircuitBreaker) evaluateStateTransition(state *ProviderCircuitState, currentState CircuitState, metrics WindowMetrics, lastCall CallResult) CircuitState { + switch currentState { + case StateClosed: + // Check if failure rate or slow call rate exceeds thresholds + if metrics.TotalCalls >= p.config.MinimumNumberOfCalls { + if metrics.FailureRate >= p.config.FailureRateThreshold || + metrics.SlowCallRate >= p.config.SlowCallRateThreshold { + return StateOpen + } + } + return StateClosed + + case StateHalfOpen: + // If last call failed (server error) or was slow, go back to open + if !lastCall.Success || lastCall.IsSlowCall { + return StateOpen + } + + // If we've made all permitted calls successfully, close circuit + if lastCall.Success && !lastCall.IsSlowCall { + atomic.AddInt32(&state.halfOpenSuccesses, 1) + } + successes := atomic.LoadInt32(&state.halfOpenSuccesses) + if successes >= int32(p.config.PermittedNumberOfCallsInHalfOpenState) { + return StateClosed + } + return StateHalfOpen + + case StateOpen: + // Should only transition via shouldTransitionToHalfOpen check + return StateOpen + } + + return currentState +} + +// transitionState performs the actual state transition and resets relevant counters. +func (p *CircuitBreaker) transitionState(state *ProviderCircuitState, newState CircuitState) { + atomic.StoreInt32(&state.state, int32(newState)) + atomic.StoreInt64(&state.stateTransitionTime, time.Now().UnixNano()) + + // Reset counters based on new state + switch newState { + case StateClosed: + atomic.StoreInt32(&state.halfOpenCallsAttempted, 0) + state.mu.Lock() + state.slidingWindow.Reset() + state.mu.Unlock() + case StateOpen: + atomic.StoreInt32(&state.halfOpenCallsAttempted, 0) + case StateHalfOpen: + atomic.StoreInt32(&state.halfOpenCallsAttempted, 0) + } + + atomic.StoreInt32(&state.halfOpenSuccesses, 0) +} + +// transitionToHalfOpen transitions from open to half-open state. +func (p *CircuitBreaker) transitionToHalfOpen(state *ProviderCircuitState) { + atomic.StoreInt32(&state.state, int32(StateHalfOpen)) + atomic.StoreInt64(&state.stateTransitionTime, time.Now().UnixNano()) + atomic.StoreInt32(&state.halfOpenCallsAttempted, 0) +} + +// GetMetrics returns metrics for a specific provider. +func (p *CircuitBreaker) GetMetrics(provider schemas.ModelProvider) (*CircuitBreakerMetrics, error) { + state, exists := p.GetProviderState(provider) + if !exists { + return nil, fmt.Errorf("provider %s not found", provider) + } + + state.mu.RLock() + metrics := state.slidingWindow.GetMetrics() + state.mu.RUnlock() + + return &CircuitBreakerMetrics{ + State: CircuitState(atomic.LoadInt32(&state.state)), + FailureRate: metrics.FailureRate, + SlowCallRate: metrics.SlowCallRate, + TotalCalls: metrics.TotalCalls, + FailedCalls: metrics.FailedCalls, + SlowCalls: metrics.SlowCalls, + StateTransitionTime: time.Unix(0, atomic.LoadInt64(&state.stateTransitionTime)), + InFlightCalls: int(atomic.LoadInt32(&state.inFlightCalls)), + HalfOpenCallsAttempted: int(atomic.LoadInt32(&state.halfOpenCallsAttempted)), + HalfOpenCallsPermitted: int(atomic.LoadInt32(&state.halfOpenCallsPermitted)), + HalfOpenSuccesses: int(atomic.LoadInt32(&state.halfOpenSuccesses)), + }, nil +} + +// Cleanup implements the Plugin interface +func (p *CircuitBreaker) Cleanup() error { + p.mu.Lock() + defer p.mu.Unlock() + + // Clear all provider states + p.providers = make(map[schemas.ModelProvider]*ProviderCircuitState) + return nil +} + diff --git a/plugins/circuitbreaker/plugin_test.go b/plugins/circuitbreaker/plugin_test.go new file mode 100644 index 0000000000..b5c949babb --- /dev/null +++ b/plugins/circuitbreaker/plugin_test.go @@ -0,0 +1,442 @@ +package circuitbreaker + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// TestCircuitBreakerBasicFunctionality tests basic circuit breaker operations +func TestCircuitBreakerBasicFunctionality(t *testing.T) { + config := CircuitBreakerConfig{ + FailureRateThreshold: 0.5, + SlowCallRateThreshold: 0.5, + SlowCallDurationThreshold: 100 * time.Millisecond, + MinimumNumberOfCalls: 5, + SlidingWindowType: CountBased, + SlidingWindowSize: 10, + PermittedNumberOfCallsInHalfOpenState: 2, + MaxWaitDurationInHalfOpenState: 1 * time.Second, + } + + cb, err := NewCircuitBreakerPlugin(config) + if err != nil { + t.Fatalf("Failed to create circuit breaker: %v", err) + } + + // Test initial state + if cb.GetName() != PluginName { + t.Errorf("Expected plugin name %s, got %s", PluginName, cb.GetName()) + } + + // Test default config + defaultCB, err := NewCircuitBreakerPlugin(DefaultConfig()) + if err != nil { + t.Fatalf("Failed to create default circuit breaker: %v", err) + } + if defaultCB == nil { + t.Error("Default circuit breaker should not be nil") + } +} + +// TestCircuitBreakerStateTransitions tests state transitions from closed to open to half-open +func TestCircuitBreakerStateTransitions(t *testing.T) { + config := CircuitBreakerConfig{ + FailureRateThreshold: 0.5, + SlowCallRateThreshold: 0.5, + SlowCallDurationThreshold: 100 * time.Millisecond, + MinimumNumberOfCalls: 5, + SlidingWindowType: CountBased, + SlidingWindowSize: 10, + PermittedNumberOfCallsInHalfOpenState: 2, + MaxWaitDurationInHalfOpenState: 100 * time.Millisecond, // Short for testing + } + + cb, err := NewCircuitBreakerPlugin(config) + if err != nil { + t.Fatalf("Failed to create circuit breaker: %v", err) + } + provider := schemas.Ollama + + // Clean up any existing state + if err := cb.Cleanup(); err != nil { + t.Fatalf("Failed to cleanup: %v", err) + } + + // Test initial state should be closed + state := cb.GetState(provider) + if state != StateClosed { + t.Errorf("Expected initial state CLOSED, got %s", state) + } + + // Simulate failures to trigger open state + for i := 0; i < 5; i++ { + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + // PreHook should succeed + _, _, err := cb.PreHook(&ctx, req) + if err != nil { + t.Errorf("PreHook failed on iteration %d: %v", i, err) + } + + // PostHook with server error (5xx) + serverError := &schemas.BifrostError{ + StatusCode: &[]int{500}[0], + Error: schemas.ErrorField{ + Message: "Internal Server Error", + }, + } + + _, _, err = cb.PostHook(&ctx, nil, serverError) + if err != nil { + t.Errorf("PostHook failed on iteration %d: %v", i, err) + } + + // Check state after each iteration + currentState := cb.GetState(provider) + t.Logf("State after iteration %d: %s", i, currentState) + } + + // Make one more failure to ensure circuit is open + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + // This PreHook should fail because circuit is now open + _, shortCircuit, err := cb.PreHook(&ctx, req) + if err != nil { + t.Errorf("PreHook should not return error, got: %v", err) + } + if shortCircuit == nil { + t.Error("Expected PreHook to return PluginShortCircuit when circuit is open") + } else if shortCircuit.Error == nil { + t.Error("Expected PluginShortCircuit to contain error when circuit is open") + } else { + t.Logf("Got expected PluginShortCircuit when circuit is open: %v", shortCircuit.Error.Error.Message) + } + + // Check if circuit is now open + state = cb.GetState(provider) + if state != StateOpen { + t.Errorf("Expected state OPEN after failures, got %s", state) + } + + // Test that requests are blocked when circuit is open (immediately after opening) + ctx = context.Background() + req = &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + // Check state before trying to make request + stateBefore := cb.GetState(provider) + t.Logf("State before blocked request: %s", stateBefore) + + _, shortCircuit, err = cb.PreHook(&ctx, req) + if err != nil { + t.Errorf("PreHook should not return error, got: %v", err) + } + if shortCircuit == nil { + t.Error("Expected PluginShortCircuit when circuit is open") + } else if shortCircuit.Error == nil { + t.Error("Expected PluginShortCircuit to contain error when circuit is open") + } else { + t.Logf("Got expected PluginShortCircuit when circuit is open: %v", shortCircuit.Error.Error.Message) + } + + // Wait for half-open transition + time.Sleep(150 * time.Millisecond) + + // Test half-open state - should succeed now + _, shortCircuit, err = cb.PreHook(&ctx, req) + if err != nil { + t.Errorf("PreHook should succeed in half-open state: %v", err) + } + if shortCircuit != nil { + t.Error("Expected no PluginShortCircuit in half-open state when call is permitted") + } + + // PostHook with success + _, _, err = cb.PostHook(&ctx, &schemas.BifrostResponse{}, nil) + if err != nil { + t.Errorf("PostHook failed: %v", err) + } + + // Check if circuit is still half-open (need more successful calls to close) + state = cb.GetState(provider) + if state != StateHalfOpen { + t.Errorf("Expected state HALF_OPEN after one successful call, got %s", state) + } + + // Make one more successful call to close the circuit + ctx2 := context.Background() + req2 := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + _, shortCircuit, err = cb.PreHook(&ctx2, req2) + if err != nil { + t.Errorf("PreHook should succeed in half-open state: %v", err) + } + if shortCircuit != nil { + t.Error("Expected no PluginShortCircuit in half-open state when call is permitted") + } + + _, _, err = cb.PostHook(&ctx2, &schemas.BifrostResponse{}, nil) + if err != nil { + t.Errorf("PostHook failed: %v", err) + } + + // Now circuit should be closed + state = cb.GetState(provider) + if state != StateClosed { + t.Errorf("Expected state CLOSED after two successful calls, got %s", state) + } +} + +// TestCircuitBreakerRecovery tests recovery from open state to closed state +func TestCircuitBreakerRecovery(t *testing.T) { + config := CircuitBreakerConfig{ + FailureRateThreshold: 0.5, + SlowCallRateThreshold: 0.5, + SlowCallDurationThreshold: 100 * time.Millisecond, + MinimumNumberOfCalls: 5, + SlidingWindowType: CountBased, + SlidingWindowSize: 10, + PermittedNumberOfCallsInHalfOpenState: 2, + MaxWaitDurationInHalfOpenState: 100 * time.Millisecond, + } + + cb, err := NewCircuitBreakerPlugin(config) + if err != nil { + t.Fatalf("Failed to create circuit breaker: %v", err) + } + provider := schemas.Ollama + + // Clean up any existing state + if err := cb.Cleanup(); err != nil { + t.Fatalf("Failed to cleanup: %v", err) + } + + // First, open the circuit + for i := 0; i < 6; i++ { + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + cb.PreHook(&ctx, req) + serverError := &schemas.BifrostError{ + StatusCode: &[]int{500}[0], + Error: schemas.ErrorField{ + Message: "Internal Server Error", + }, + } + cb.PostHook(&ctx, nil, serverError) + } + + // Verify circuit is open + state := cb.GetState(provider) + if state != StateOpen { + t.Errorf("Expected state OPEN, got %s", state) + } + + // Wait for half-open transition + time.Sleep(150 * time.Millisecond) + + // Make successful calls in half-open state + for range 2 { + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + cb.PreHook(&ctx, req) + cb.PostHook(&ctx, &schemas.BifrostResponse{}, nil) + } + + // Verify circuit is now closed + state = cb.GetState(provider) + if state != StateClosed { + t.Errorf("Expected state CLOSED after recovery, got %s", state) + } +} + +// TestCircuitBreakerSlowCalls tests slow call detection +func TestCircuitBreakerSlowCalls(t *testing.T) { + config := CircuitBreakerConfig{ + FailureRateThreshold: 0.5, + SlowCallRateThreshold: 0.5, + SlowCallDurationThreshold: 50 * time.Millisecond, + MinimumNumberOfCalls: 5, + SlidingWindowType: CountBased, + SlidingWindowSize: 10, + PermittedNumberOfCallsInHalfOpenState: 2, + MaxWaitDurationInHalfOpenState: 1 * time.Second, + } + + cb, err := NewCircuitBreakerPlugin(config) + if err != nil { + t.Fatalf("Failed to create circuit breaker: %v", err) + } + provider := schemas.Ollama + + // Make slow calls + for range 6 { + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + cb.PreHook(&ctx, req) + + // Simulate slow call by adding delay + time.Sleep(60 * time.Millisecond) + + cb.PostHook(&ctx, &schemas.BifrostResponse{}, nil) + } + + // Circuit should be open due to slow calls + state := cb.GetState(provider) + if state != StateOpen { + t.Errorf("Expected state OPEN with slow calls, got %s", state) + } +} + +// TestCircuitBreakerMetrics tests metrics collection +func TestCircuitBreakerMetrics(t *testing.T) { + config := CircuitBreakerConfig{ + FailureRateThreshold: 0.5, + SlowCallRateThreshold: 0.5, + SlowCallDurationThreshold: 100 * time.Millisecond, + MinimumNumberOfCalls: 5, + SlidingWindowType: CountBased, + SlidingWindowSize: 10, + PermittedNumberOfCallsInHalfOpenState: 2, + MaxWaitDurationInHalfOpenState: 1 * time.Second, + } + + cb, err := NewCircuitBreakerPlugin(config) + if err != nil { + t.Fatalf("Failed to create circuit breaker: %v", err) + } + provider := schemas.Ollama + + // Make some calls + for range 3 { + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + cb.PreHook(&ctx, req) + cb.PostHook(&ctx, &schemas.BifrostResponse{}, nil) + } + + // Make some failures + for range 2 { + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + cb.PreHook(&ctx, req) + serverError := &schemas.BifrostError{ + StatusCode: &[]int{500}[0], + Error: schemas.ErrorField{ + Message: "Internal Server Error", + }, + } + cb.PostHook(&ctx, nil, serverError) + } + + // Check metrics + metrics, err := cb.GetMetrics(provider) + if err != nil { + t.Errorf("Failed to get metrics: %v", err) + } + + if metrics.TotalCalls != 5 { + t.Errorf("Expected 5 total calls, got %d", metrics.TotalCalls) + } + + if metrics.FailedCalls != 2 { + t.Errorf("Expected 2 failed calls, got %d", metrics.FailedCalls) + } + + expectedFailureRate := 0.4 // 2/5 + if metrics.FailureRate != expectedFailureRate { + t.Errorf("Expected failure rate %f, got %f", expectedFailureRate, metrics.FailureRate) + } +} + +// TestCircuitBreakerTimeBasedWindow tests time-based sliding window +func TestCircuitBreakerTimeBasedWindow(t *testing.T) { + config := CircuitBreakerConfig{ + FailureRateThreshold: 0.5, + SlowCallRateThreshold: 0.5, + SlowCallDurationThreshold: 100 * time.Millisecond, + MinimumNumberOfCalls: 5, + SlidingWindowType: TimeBased, + SlidingWindowSize: 1, // 1 second window + PermittedNumberOfCallsInHalfOpenState: 2, + MaxWaitDurationInHalfOpenState: 1 * time.Second, + } + + cb, err := NewCircuitBreakerPlugin(config) + if err != nil { + t.Fatalf("Failed to create circuit breaker: %v", err) + } + provider := schemas.Ollama + + // Make calls within the time window + for range 5 { + ctx := context.Background() + req := &schemas.BifrostRequest{ + Provider: provider, + Model: "test-model", + } + + cb.PreHook(&ctx, req) + serverError := &schemas.BifrostError{ + StatusCode: &[]int{500}[0], + Error: schemas.ErrorField{ + Message: "Internal Server Error", + }, + } + cb.PostHook(&ctx, nil, serverError) + } + + // Circuit should be open + state := cb.GetState(provider) + if state != StateOpen { + t.Errorf("Expected state OPEN with time-based window, got %s", state) + } + + // Wait for window to expire + time.Sleep(1000 * time.Millisecond) + + // Check metrics - should be reset + metrics, err := cb.GetMetrics(provider) + if err != nil { + t.Errorf("Failed to get metrics: %v", err) + } + + if metrics.TotalCalls != 0 { + t.Errorf("Expected 0 total calls after window expiry, got %d", metrics.TotalCalls) + } +} diff --git a/plugins/circuitbreaker/utils.go b/plugins/circuitbreaker/utils.go new file mode 100644 index 0000000000..0ce002402f --- /dev/null +++ b/plugins/circuitbreaker/utils.go @@ -0,0 +1,299 @@ +package circuitbreaker + +import ( + "context" + "fmt" + "sync/atomic" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// DefaultConfig returns a default circuit breaker configuration +func DefaultConfig() CircuitBreakerConfig { + return CircuitBreakerConfig{ + FailureRateThreshold: 0.5, // 50% failure rate threshold + SlowCallRateThreshold: 0.5, // 50% slow call rate threshold + SlowCallDurationThreshold: 5 * time.Second, // 5 seconds + MinimumNumberOfCalls: 10, // Minimum 10 calls before evaluation + SlidingWindowType: CountBased, // Count-based sliding window + SlidingWindowSize: 100, // 100 calls window + PermittedNumberOfCallsInHalfOpenState: 5, // 5 calls in half-open state + MaxWaitDurationInHalfOpenState: 60 * time.Second, // 60 seconds wait time + } +} + +// ValidateConfigWithDefaults validates the configuration and returns information about what defaults were applied. +// It returns the validated config and a slice of strings describing what defaults were applied. +func ValidateConfigWithDefaults(config CircuitBreakerConfig) (CircuitBreakerConfig, []string) { + defaults := DefaultConfig() + validated := config + var appliedDefaults []string + + // Apply defaults for invalid values and track what was applied + if validated.FailureRateThreshold < 0 || validated.FailureRateThreshold > 1 { + validated.FailureRateThreshold = defaults.FailureRateThreshold + appliedDefaults = append(appliedDefaults, fmt.Sprintf("failure rate threshold set to default: %f", defaults.FailureRateThreshold)) + } + if validated.SlowCallRateThreshold < 0 || validated.SlowCallRateThreshold > 1 { + validated.SlowCallRateThreshold = defaults.SlowCallRateThreshold + appliedDefaults = append(appliedDefaults, fmt.Sprintf("slow call rate threshold set to default: %f", defaults.SlowCallRateThreshold)) + } + if validated.SlowCallDurationThreshold <= 0 { + validated.SlowCallDurationThreshold = defaults.SlowCallDurationThreshold + appliedDefaults = append(appliedDefaults, fmt.Sprintf("slow call duration threshold set to default: %v", defaults.SlowCallDurationThreshold)) + } + if validated.MinimumNumberOfCalls <= 0 { + validated.MinimumNumberOfCalls = defaults.MinimumNumberOfCalls + appliedDefaults = append(appliedDefaults, fmt.Sprintf("minimum number of calls set to default: %d", defaults.MinimumNumberOfCalls)) + } + if validated.SlidingWindowSize <= 0 { + validated.SlidingWindowSize = defaults.SlidingWindowSize + appliedDefaults = append(appliedDefaults, fmt.Sprintf("sliding window size set to default: %d", defaults.SlidingWindowSize)) + } + if validated.PermittedNumberOfCallsInHalfOpenState <= 0 { + validated.PermittedNumberOfCallsInHalfOpenState = defaults.PermittedNumberOfCallsInHalfOpenState + appliedDefaults = append(appliedDefaults, fmt.Sprintf("permitted calls in half-open state set to default: %d", defaults.PermittedNumberOfCallsInHalfOpenState)) + } + if validated.MaxWaitDurationInHalfOpenState <= 0 { + validated.MaxWaitDurationInHalfOpenState = defaults.MaxWaitDurationInHalfOpenState + appliedDefaults = append(appliedDefaults, fmt.Sprintf("max wait duration in half-open state set to default: %v", defaults.MaxWaitDurationInHalfOpenState)) + } + + // Validate sliding window type + if validated.SlidingWindowType != CountBased && validated.SlidingWindowType != TimeBased { + validated.SlidingWindowType = defaults.SlidingWindowType + appliedDefaults = append(appliedDefaults, fmt.Sprintf("sliding window type set to default: %s", defaults.SlidingWindowType)) + } + if validated.Logger == nil { + validated.Logger = bifrost.NewDefaultLogger(schemas.LogLevelInfo) + appliedDefaults = append(appliedDefaults, fmt.Sprintf("logger set to default with level: %s", schemas.LogLevelInfo)) + } + + return validated, appliedDefaults +} + +// GetState returns the current circuit breaker state for a provider +func (p *CircuitBreaker) GetState(provider schemas.ModelProvider) CircuitState { + state := p.getOrCreateProviderState(provider) + currentState := CircuitState(atomic.LoadInt32(&state.state)) + return currentState +} + +// ForceOpen forces the circuit breaker to open state for a provider +func (p *CircuitBreaker) ForceOpen(provider schemas.ModelProvider) error { + state, exists := p.GetProviderState(provider) + if !exists { + return fmt.Errorf("provider %s not found", provider) + } + + p.transitionState(state, StateOpen) + return nil +} + +// ForceClose forces the circuit breaker to closed state for a provider +func (p *CircuitBreaker) ForceClose(provider schemas.ModelProvider) error { + state, exists := p.GetProviderState(provider) + if !exists { + return fmt.Errorf("provider %s not found", provider) + } + + p.transitionState(state, StateClosed) + return nil +} + +// Reset resets the circuit breaker state for a provider +func (p *CircuitBreaker) Reset(provider schemas.ModelProvider) error { + state, exists := p.GetProviderState(provider) + if !exists { + return fmt.Errorf("provider %s not found", provider) + } + + // Reset to closed state and clear all metrics + p.transitionState(state, StateClosed) + return nil +} + +// IsServerError checks if a BifrostError represents a server error (5xx status code) +func IsServerError(bifrostErr *schemas.BifrostError) bool { + if bifrostErr == nil || bifrostErr.StatusCode == nil { + return false + } + statusCode := *bifrostErr.StatusCode + return statusCode >= 500 && statusCode < 600 +} + +// IsRateLimitExceeded checks if a BifrostError represents a "Too Many Requests" (429 status code) +func IsRateLimitExceeded(bifrostErr *schemas.BifrostError) bool { + if bifrostErr == nil || bifrostErr.StatusCode == nil { + return false + } + statusCode := *bifrostErr.StatusCode + return statusCode == 429 +} + +// RecordCall adds a new call result to the count-based sliding window. +func (w *CountBasedWindow) RecordCall(result CallResult) { + w.mu.Lock() + defer w.mu.Unlock() + + w.calls[w.position] = result + w.position = (w.position + 1) % w.maxSize + if !w.full && w.position == 0 { + w.full = true + } +} + +// GetMetrics calculates and returns metrics for the count-based sliding window. +func (w *CountBasedWindow) GetMetrics() WindowMetrics { + w.mu.RLock() + defer w.mu.RUnlock() + + var totalCalls, failedCalls, slowCalls int + callCount := w.maxSize + if !w.full { + callCount = w.position + } + + for i := 0; i < callCount; i++ { + call := w.calls[i] + totalCalls++ + if !call.Success { + failedCalls++ + } + if call.IsSlowCall { + slowCalls++ + } + } + + var failureRate, slowCallRate float64 + if totalCalls > 0 { + failureRate = float64(failedCalls) / float64(totalCalls) + slowCallRate = float64(slowCalls) / float64(totalCalls) + } + + return WindowMetrics{ + TotalCalls: totalCalls, + FailedCalls: failedCalls, + SlowCalls: slowCalls, + FailureRate: failureRate, + SlowCallRate: slowCallRate, + } +} + +// Reset clears all call data from the count-based sliding window. +func (w *CountBasedWindow) Reset() { + w.mu.Lock() + defer w.mu.Unlock() + + w.calls = make([]CallResult, w.maxSize) + w.position = 0 + w.full = false +} + +// RecordCall adds a new call result to the time-based sliding window. +// using periodic cleanup to improve performance for high-frequency calls. +func (w *TimeBasedWindow) RecordCall(result CallResult) { + w.mu.Lock() + defer w.mu.Unlock() + + w.calls = append(w.calls, result) + + // Trigger cleanup based on conditions: + // 1. If we've exceeded the maximum calls threshold + // 2. If we've reached the cleanup threshold and enough time has passed + // 3. If we've accumulated too many calls + if len(w.calls) >= w.maxCallsBeforeCleanup || + (len(w.calls) >= w.cleanupThreshold && time.Since(w.lastCleanup) > w.windowDuration/4) { + w.cleanupExpiredEntries() + } +} + +// cleanupExpiredEntries removes expired call results from the sliding window. +func (w *TimeBasedWindow) cleanupExpiredEntries() { + cutoffTime := time.Now().Add(-w.windowDuration) + trimIdx := 0 + + // Find the first call that's still within the window + for i, call := range w.calls { + if call.Timestamp.After(cutoffTime) { + trimIdx = i + break + } + } + + // Remove expired entries + if trimIdx > 0 { + w.calls = w.calls[trimIdx:] + } + + w.lastCleanup = time.Now() +} + + +// GetMetrics calculates and returns metrics for the time-based sliding window. +func (w *TimeBasedWindow) GetMetrics() WindowMetrics { + // Trigger cleanup if needed before calculating metrics + w.mu.Lock() + if len(w.calls) > 0 && (len(w.calls) >= w.maxCallsBeforeCleanup || time.Since(w.lastCleanup) > w.windowDuration/2) { + w.cleanupExpiredEntries() + } + w.mu.Unlock() + + w.mu.RLock() + defer w.mu.RUnlock() + + var totalCalls, failedCalls, slowCalls int + cutoffTime := time.Now().Add(-w.windowDuration) + + for _, call := range w.calls { + if call.Timestamp.After(cutoffTime) { + totalCalls++ + if !call.Success { + failedCalls++ + } + if call.IsSlowCall { + slowCalls++ + } + } + } + + var failureRate, slowCallRate float64 + if totalCalls > 0 { + failureRate = float64(failedCalls) / float64(totalCalls) + slowCallRate = float64(slowCalls) / float64(totalCalls) + } + + return WindowMetrics{ + TotalCalls: totalCalls, + FailedCalls: failedCalls, + SlowCalls: slowCalls, + FailureRate: failureRate, + SlowCallRate: slowCallRate, + } +} + +// Reset clears all call data from the time-based sliding window. +func (w *TimeBasedWindow) Reset() { + w.mu.Lock() + defer w.mu.Unlock() + + w.calls = make([]CallResult, 0) + w.lastCleanup = time.Now() +} + +// Utility functions for context extraction +func GetCallStartTime(ctx context.Context) time.Time { + if startTime, ok := ctx.Value(callStartTimeKey).(time.Time); ok { + return startTime + } + return time.Now() // Fallback +} + +func GetCircuitState(ctx context.Context) *ProviderCircuitState { + if state, ok := ctx.Value(circuitStateKey).(*ProviderCircuitState); ok { + return state + } + return nil +} diff --git a/tests/core-providers/go.mod b/tests/core-providers/go.mod index 1ed870d853..1bd65ee304 100644 --- a/tests/core-providers/go.mod +++ b/tests/core-providers/go.mod @@ -3,7 +3,7 @@ module github.com/maximhq/bifrost/tests/core-providers go 1.24.1 require ( - github.com/maximhq/bifrost/core v1.1.5 + github.com/maximhq/bifrost/core v1.1.6 github.com/stretchr/testify v1.10.0 ) diff --git a/tests/core-providers/go.sum b/tests/core-providers/go.sum index baed34d884..e7dc099ad7 100644 --- a/tests/core-providers/go.sum +++ b/tests/core-providers/go.sum @@ -40,14 +40,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= -github.com/maximhq/bifrost/core v1.1.5 h1:Nm9XlS9Nso+pn+U5/btsJD8qRDYGQ1BBOjgqWT3PYSc= -github.com/maximhq/bifrost/core v1.1.5/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=