diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh index 515d32c944..d8a001bed9 100755 --- a/.github/workflows/scripts/run-migration-tests.sh +++ b/.github/workflows/scripts/run-migration-tests.sh @@ -1400,6 +1400,29 @@ append_dynamic_columns_postgres() { echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-001';" >> "$output_file" echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-002';" >> "$output_file" fi + + # ------------------------------------------------------------------------- + # v1.4.22 columns - flex tier pricing and litellm fallbacks toggle + # ------------------------------------------------------------------------- + + # config_client.enable_litellm_fallbacks (added in v1.4.22) + if column_exists_postgres "config_client" "enable_litellm_fallbacks"; then + echo "UPDATE config_client SET enable_litellm_fallbacks = false WHERE id = 1;" >> "$output_file" + fi + + # governance_model_pricing flex tier columns (added in v1.4.22) + if column_exists_postgres "governance_model_pricing" "input_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "output_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "cache_read_input_token_cost_flex"; then + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" + fi } # Append dynamic column UPDATEs for columns that may not exist in older schemas (SQLite) @@ -2030,6 +2053,31 @@ append_dynamic_columns_sqlite() { # mcp_tool_logs.request_id (added in v1.4.21) echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-001';" >> "$output_file" echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-002';" >> "$output_file" + + # ------------------------------------------------------------------------- + # v1.4.22 columns - flex tier pricing and litellm fallbacks toggle + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # config_client.enable_litellm_fallbacks (added in v1.4.22) + if column_exists_sqlite "$config_db" "config_client" "enable_litellm_fallbacks"; then + echo "UPDATE config_client SET enable_litellm_fallbacks = 0 WHERE id = 1;" >> "$output_file" + fi + + # governance_model_pricing flex tier columns (added in v1.4.22) + if column_exists_sqlite "$config_db" "governance_model_pricing" "input_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "output_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "cache_read_input_token_cost_flex"; then + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" + fi + fi } # ============================================================================ @@ -2933,6 +2981,10 @@ compare_postgres_snapshots() { if [ "$table" = "governance_budgets" ]; then dropped_columns="$dropped_columns calendar_aligned" fi + # enable_litellm_fallbacks (dropped from config_client in latest cut - behavior moved elsewhere) + if [ "$table" = "config_client" ]; then + dropped_columns="$dropped_columns enable_litellm_fallbacks" + fi local before_col_array IFS=',' read -ra before_col_array <<< "$before_columns" diff --git a/.gitignore b/.gitignore index b9119b3191..149b237ff4 100644 --- a/.gitignore +++ b/.gitignore @@ -121,4 +121,7 @@ terraform.tfstate.backup bifrost-benchmarking # Tests -:memory: \ No newline at end of file +:memory: + +# Generated test TLS certs (created by tests/docker-compose.yml redis-certs-init) +tests/redis-certs/ \ No newline at end of file diff --git a/core/changelog.md b/core/changelog.md index e69de29bb2..e51a878da8 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1,4 @@ +- refactor: split ModelRequested into OriginalModelRequested and ResolvedModelUsed for model alias tracking +- refactor: simplify Azure passthrough by removing redundant config nil checks +- refactor: simplify Mistral error parsing signature +- fix: carry ProviderResponseHeaders through text completion response conversion diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index b5d1d242dc..5912384bdb 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -212,7 +212,7 @@ const ( BifrostContextKeyStreamStartTime BifrostContextKey = "bifrost-stream-start-time" // time.Time (start time for streaming TTFT calculation - set by bifrost) BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost) BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers) - BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) + BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func([]PluginLogEntry) (callback to complete trace after streaming, receives transport plugin logs - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost) diff --git a/core/version b/core/version index 26ca594609..4cda8f19ed 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.5.1 +1.5.2 diff --git a/framework/changelog.md b/framework/changelog.md index e69de29bb2..4973abcc9b 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -0,0 +1,4 @@ +- feat: add MCP client discovered tools and tool name mapping migration +- fix: exception handling in async log store jobs +- refactor: model catalog Init API to use SetShouldSyncGate method +- refactor: rename DefaultPricingSyncInterval to DefaultSyncInterval diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go index e9d771fdde..b03afaa7ff 100644 --- a/framework/configstore/migrations_test.go +++ b/framework/configstore/migrations_test.go @@ -3,6 +3,7 @@ package configstore import ( "bytes" "context" + "encoding/json" "fmt" "log" "os" @@ -11,6 +12,8 @@ import ( "testing" "time" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" "github.com/stretchr/testify/assert" @@ -1090,3 +1093,912 @@ func TestMigrationDropDeploymentColumnsAndAddAliases_BedrockEncrypted(t *testing assert.Equal(t, "dep-claude", aliases["claude"]) assert.Equal(t, "dep-instant", aliases["claude-instant"]) } + +// ============================================================================ +// Helper: full migration DB setup +// ============================================================================ + +// setupFullMigrationDB creates a fresh in-memory SQLite database and runs the +// full triggerMigrations chain (the same code path as production startup). +// Returns an RDBConfigStore for CRUD verification and the raw *gorm.DB for +// low-level assertions. +// testDBCounter ensures each test gets a unique shared-cache in-memory SQLite URI. +var testDBCounter int64 + +func setupFullMigrationDB(t *testing.T) (*RDBConfigStore, *gorm.DB) { + t.Helper() + // Use a unique shared-cache URI so all connections see the same in-memory DB + // without requiring MaxOpenConns(1) (which deadlocks when code opens transactions + // and queries on s.db concurrently). + n := time.Now().UnixNano() + testDBCounter + testDBCounter++ + dsn := fmt.Sprintf("file:testdb_%d?mode=memory&cache=shared", n) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err, "Failed to create test database") + + ctx := context.Background() + err = triggerMigrations(ctx, db) + require.NoError(t, err, "triggerMigrations should succeed on a fresh DB") + + store := &RDBConfigStore{ + db: db, + logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + } + return store, db +} + +// ============================================================================ +// Part 2: Full-chain integration tests +// ============================================================================ + +func TestTriggerMigrations_FreshDB(t *testing.T) { + _, db := setupFullMigrationDB(t) + + // Every critical table should exist after the full migration chain. + criticalTables := []interface{}{ + &tables.TableProvider{}, + &tables.TableKey{}, + &tables.TableVirtualKey{}, + &tables.TableMCPClient{}, + &tables.TableBudget{}, + &tables.TableRateLimit{}, + &tables.TablePlugin{}, + &tables.TableCustomer{}, + &tables.TableTeam{}, + &tables.SessionsTable{}, + &tables.TableOauthConfig{}, + &tables.TableOauthToken{}, + &tables.TableModelPricing{}, + &tables.TableGovernanceConfig{}, + &tables.TableClientConfig{}, + &tables.TableVirtualKeyProviderConfig{}, + &tables.TableVirtualKeyMCPConfig{}, + } + + migrator := db.Migrator() + for _, table := range criticalTables { + assert.True(t, migrator.HasTable(table), "table should exist: %T", table) + } +} + +func TestTriggerMigrations_Idempotent(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + + ctx := context.Background() + + // First run + err = triggerMigrations(ctx, db) + require.NoError(t, err, "first triggerMigrations should succeed") + + // Second run – must be a no-op, not an error. + err = triggerMigrations(ctx, db) + require.NoError(t, err, "second triggerMigrations should succeed (idempotent)") + + // Tables should still be intact + assert.True(t, db.Migrator().HasTable(&tables.TableProvider{}), "TableProvider should still exist") + assert.True(t, db.Migrator().HasTable(&tables.TableKey{}), "TableKey should still exist") + assert.True(t, db.Migrator().HasTable(&tables.TableVirtualKey{}), "TableVirtualKey should still exist") +} + +func TestFullMigration_ProviderAndKeyCRUD(t *testing.T) { + if !encrypt.IsEnabled() { + t.Skip("encryption not enabled") + } + + store, db := setupFullMigrationDB(t) + ctx := context.Background() + + config := ProviderConfig{ + Keys: []schemas.Key{ + { + ID: "key-uuid-1", + Name: "openai-primary", + Value: *schemas.NewEnvVar("sk-test-secret-key-12345"), + Models: schemas.WhiteList{"*"}, + Weight: 1.0, + }, + }, + } + + err := store.AddProvider(ctx, "openai", config) + require.NoError(t, err) + + // Read back through store (AfterFind hooks decrypt) + result, err := store.GetProvidersConfig(ctx) + require.NoError(t, err) + require.Len(t, result, 1) + require.Contains(t, result, schemas.ModelProvider("openai")) + + openai := result[schemas.ModelProvider("openai")] + require.Len(t, openai.Keys, 1) + assert.Equal(t, "openai-primary", openai.Keys[0].Name) + assert.Equal(t, "sk-test-secret-key-12345", openai.Keys[0].Value.GetValue()) + + // Verify encryption at the raw DB level + var rawValue string + var rawStatus string + err = db.Table("config_keys").Select("value").Where("key_id = ?", "key-uuid-1").Scan(&rawValue).Error + require.NoError(t, err) + err = db.Table("config_keys").Select("encryption_status").Where("key_id = ?", "key-uuid-1").Scan(&rawStatus).Error + require.NoError(t, err) + + assert.NotEqual(t, "sk-test-secret-key-12345", rawValue, "key value should be encrypted at rest") + assert.Equal(t, "encrypted", rawStatus) +} + +func TestFullMigration_VirtualKeyCRUD(t *testing.T) { + if !encrypt.IsEnabled() { + t.Skip("encryption not enabled") + } + + store, db := setupFullMigrationDB(t) + ctx := context.Background() + now := time.Now() + + vk := &tables.TableVirtualKey{ + ID: "vk-test-001", + Name: "test-virtual-key", + Value: "vk-secret-value-12345", + IsActive: true, + CreatedAt: now, + UpdatedAt: now, + } + + err := store.CreateVirtualKey(ctx, vk) + require.NoError(t, err) + + // Read back + vks, err := store.GetVirtualKeys(ctx) + require.NoError(t, err) + require.Len(t, vks, 1) + + assert.Equal(t, "vk-test-001", vks[0].ID) + assert.Equal(t, "test-virtual-key", vks[0].Name) + assert.Equal(t, "vk-secret-value-12345", vks[0].Value) // AfterFind decrypts + assert.True(t, vks[0].IsActive) + + // Verify encryption at raw DB level + var rawValue, rawStatus, rawHash string + err = db.Table("governance_virtual_keys"). + Select("value").Where("id = ?", "vk-test-001").Scan(&rawValue).Error + require.NoError(t, err) + err = db.Table("governance_virtual_keys"). + Select("encryption_status").Where("id = ?", "vk-test-001").Scan(&rawStatus).Error + require.NoError(t, err) + err = db.Table("governance_virtual_keys"). + Select("value_hash").Where("id = ?", "vk-test-001").Scan(&rawHash).Error + require.NoError(t, err) + + assert.NotEqual(t, "vk-secret-value-12345", rawValue, "VK value should be encrypted at rest") + assert.Equal(t, "encrypted", rawStatus) + assert.NotEmpty(t, rawHash, "value_hash should be populated") +} + +func TestFullMigration_MCPClientCRUD(t *testing.T) { + if !encrypt.IsEnabled() { + t.Skip("encryption not enabled") + } + + store, db := setupFullMigrationDB(t) + ctx := context.Background() + + clientConfig := &schemas.MCPClientConfig{ + ID: "mcp-client-001", + Name: "test_mcp_server", + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: schemas.NewEnvVar("https://mcp.example.com/sse"), + ToolsToExecute: schemas.WhiteList{"*"}, + } + + err := store.CreateMCPClientConfig(ctx, clientConfig) + require.NoError(t, err) + + // Read back through store + mcpClient, err := store.GetMCPClientByName(ctx, "test_mcp_server") + require.NoError(t, err) + require.NotNil(t, mcpClient) + + assert.Equal(t, "mcp-client-001", mcpClient.ClientID) + assert.Equal(t, "test_mcp_server", mcpClient.Name) + assert.Equal(t, "sse", mcpClient.ConnectionType) + assert.Equal(t, "https://mcp.example.com/sse", mcpClient.ConnectionString.GetValue()) + + // Verify encryption at raw DB level + var rawConnStr, rawStatus string + err = db.Table("config_mcp_clients"). + Select("connection_string").Where("client_id = ?", "mcp-client-001").Scan(&rawConnStr).Error + require.NoError(t, err) + err = db.Table("config_mcp_clients"). + Select("encryption_status").Where("client_id = ?", "mcp-client-001").Scan(&rawStatus).Error + require.NoError(t, err) + + assert.NotEqual(t, "https://mcp.example.com/sse", rawConnStr, "connection string should be encrypted at rest") + assert.Equal(t, "encrypted", rawStatus) +} + +func TestFullMigration_EncryptPlaintextRows(t *testing.T) { + if !encrypt.IsEnabled() { + t.Skip("encryption not enabled") + } + + store, db := setupFullMigrationDB(t) + ctx := context.Background() + now := time.Now().UTC().Format("2006-01-02 15:04:05") + + // Insert a provider first (FK for config_keys) + err := db.Exec(`INSERT INTO config_providers (name, encryption_status, created_at, updated_at) + VALUES (?, 'plain_text', ?, ?)`, "openai", now, now).Error + require.NoError(t, err) + + // Get the provider ID + var providerID uint + err = db.Table("config_providers").Select("id").Where("name = ?", "openai").Scan(&providerID).Error + require.NoError(t, err) + + // Insert plaintext key (bypassing GORM hooks) + err = db.Exec(`INSERT INTO config_keys (name, provider_id, provider, key_id, value, models_json, + encryption_status, created_at, updated_at) + VALUES (?, ?, 'openai', ?, ?, '["*"]', 'plain_text', ?, ?)`, + "plaintext-key", providerID, "pk-1", "sk-plaintext-secret", now, now).Error + require.NoError(t, err) + + // Insert plaintext virtual key + err = db.Exec(`INSERT INTO governance_virtual_keys (id, name, value, is_active, encryption_status, created_at, updated_at) + VALUES (?, ?, ?, true, 'plain_text', ?, ?)`, + "vk-plain-1", "plaintext-vk", "vk-plain-secret", now, now).Error + require.NoError(t, err) + + // Verify they are plaintext before encryption + var preStatus string + err = db.Table("config_keys").Select("encryption_status").Where("key_id = ?", "pk-1").Scan(&preStatus).Error + require.NoError(t, err) + assert.Equal(t, "plain_text", preStatus) + + // Run the encryption pass + err = store.EncryptPlaintextRows(ctx) + require.NoError(t, err) + + // Verify raw DB: encryption_status changed, values differ + var rawKeyValue, rawKeyStatus string + err = db.Table("config_keys").Select("value").Where("key_id = ?", "pk-1").Scan(&rawKeyValue).Error + require.NoError(t, err) + err = db.Table("config_keys").Select("encryption_status").Where("key_id = ?", "pk-1").Scan(&rawKeyStatus).Error + require.NoError(t, err) + assert.NotEqual(t, "sk-plaintext-secret", rawKeyValue, "key should be encrypted after EncryptPlaintextRows") + assert.Equal(t, "encrypted", rawKeyStatus) + + var rawVKValue, rawVKStatus string + err = db.Table("governance_virtual_keys").Select("value").Where("id = ?", "vk-plain-1").Scan(&rawVKValue).Error + require.NoError(t, err) + err = db.Table("governance_virtual_keys").Select("encryption_status").Where("id = ?", "vk-plain-1").Scan(&rawVKStatus).Error + require.NoError(t, err) + assert.NotEqual(t, "vk-plain-secret", rawVKValue, "VK should be encrypted after EncryptPlaintextRows") + assert.Equal(t, "encrypted", rawVKStatus) + + // Verify GORM read decrypts correctly + var key tables.TableKey + err = db.Where("key_id = ?", "pk-1").First(&key).Error + require.NoError(t, err) + assert.Equal(t, "sk-plaintext-secret", key.Value.GetValue()) + + var vk tables.TableVirtualKey + err = db.Where("id = ?", "vk-plain-1").First(&vk).Error + require.NoError(t, err) + assert.Equal(t, "vk-plain-secret", vk.Value) +} + +func TestFullMigration_EndToEnd(t *testing.T) { + if !encrypt.IsEnabled() { + t.Skip("encryption not enabled") + } + + store, db := setupFullMigrationDB(t) + ctx := context.Background() + now := time.Now() + + // Add two providers with keys + for _, p := range []struct { + provider string + keyID string + keyName string + keyValue string + }{ + {"openai", "key-oa-1", "openai-key", "sk-openai-secret"}, + {"anthropic", "key-ant-1", "anthropic-key", "sk-anthropic-secret"}, + } { + err := store.AddProvider(ctx, schemas.ModelProvider(p.provider), ProviderConfig{ + Keys: []schemas.Key{{ + ID: p.keyID, + Name: p.keyName, + Value: *schemas.NewEnvVar(p.keyValue), + Models: schemas.WhiteList{"*"}, + Weight: 1.0, + }}, + }) + require.NoError(t, err, "AddProvider %s", p.provider) + } + + // Add virtual keys + for _, vk := range []struct { + id, name, value string + }{ + {"vk-1", "vk-alpha", "vk-alpha-secret"}, + {"vk-2", "vk-beta", "vk-beta-secret"}, + } { + err := store.CreateVirtualKey(ctx, &tables.TableVirtualKey{ + ID: vk.id, Name: vk.name, Value: vk.value, + IsActive: true, CreatedAt: now, UpdatedAt: now, + }) + require.NoError(t, err, "CreateVirtualKey %s", vk.name) + } + + // Add MCP client + err := store.CreateMCPClientConfig(ctx, &schemas.MCPClientConfig{ + ID: "mcp-e2e-1", + Name: "e2e_mcp_client", + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: schemas.NewEnvVar("https://mcp.e2e.test/sse"), + ToolsToExecute: schemas.WhiteList{"*"}, + }) + require.NoError(t, err) + + // Verify providers + providers, err := store.GetProvidersConfig(ctx) + require.NoError(t, err) + assert.Len(t, providers, 2) + assert.Contains(t, providers, schemas.ModelProvider("openai")) + assert.Contains(t, providers, schemas.ModelProvider("anthropic")) + assert.Equal(t, "sk-openai-secret", providers["openai"].Keys[0].Value.GetValue()) + assert.Equal(t, "sk-anthropic-secret", providers["anthropic"].Keys[0].Value.GetValue()) + + // Verify virtual keys + vks, err := store.GetVirtualKeys(ctx) + require.NoError(t, err) + assert.Len(t, vks, 2) + + // Verify MCP client + mcpClient, err := store.GetMCPClientByName(ctx, "e2e_mcp_client") + require.NoError(t, err) + assert.Equal(t, "https://mcp.e2e.test/sse", mcpClient.ConnectionString.GetValue()) + + // Verify all sensitive data is encrypted at the raw DB level + type encCheck struct { + table, column, whereCol, whereVal, plaintext string + } + checks := []encCheck{ + {"config_keys", "value", "key_id", "key-oa-1", "sk-openai-secret"}, + {"config_keys", "value", "key_id", "key-ant-1", "sk-anthropic-secret"}, + {"governance_virtual_keys", "value", "id", "vk-1", "vk-alpha-secret"}, + {"governance_virtual_keys", "value", "id", "vk-2", "vk-beta-secret"}, + {"config_mcp_clients", "connection_string", "client_id", "mcp-e2e-1", "https://mcp.e2e.test/sse"}, + } + for _, c := range checks { + var rawVal string + err := db.Table(c.table).Select(c.column).Where(fmt.Sprintf("%s = ?", c.whereCol), c.whereVal).Scan(&rawVal).Error + require.NoError(t, err) + assert.NotEqual(t, c.plaintext, rawVal, "raw %s.%s for %s=%s should be encrypted", + c.table, c.column, c.whereCol, c.whereVal) + } +} + +// ============================================================================ +// Part 3: Individual complex migration tests +// ============================================================================ + +// setupPreEncryptionDB runs all migrations up to (but not including) the +// encryption columns migration. This is approximated by running migrationInit +// plus the essential early migrations on a fresh SQLite DB. +func setupPreEncryptionDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + + ctx := context.Background() + // Run the initial migration to create core tables + err = migrationInit(ctx, db) + require.NoError(t, err) + // Sessions table is created by a later migration, but required by the + // encryption columns migration. + err = migrationAddSessionsTable(ctx, db) + require.NoError(t, err) + + return db +} + +func TestMigrationAddEncryptionColumns(t *testing.T) { + db := setupPreEncryptionDB(t) + ctx := context.Background() + now := time.Now().UTC().Format("2006-01-02 15:04:05") + + // Insert rows BEFORE encryption columns exist (they won't have encryption_status) + err := db.Exec(`INSERT INTO config_providers (name, created_at, updated_at) VALUES (?, ?, ?)`, + "openai", now, now).Error + require.NoError(t, err) + + var providerID uint + err = db.Table("config_providers").Select("id").Where("name = ?", "openai").Scan(&providerID).Error + require.NoError(t, err) + + err = db.Exec(`INSERT INTO config_keys (name, provider_id, provider, key_id, value, created_at, updated_at) + VALUES (?, ?, 'openai', ?, ?, ?, ?)`, + "test-key", providerID, "ek-1", "sk-test", now, now).Error + require.NoError(t, err) + + err = db.Exec(`INSERT INTO governance_virtual_keys (id, name, value, is_active, created_at, updated_at) + VALUES (?, ?, ?, true, ?, ?)`, + "vk-enc-1", "enc-vk", "vk-value", now, now).Error + require.NoError(t, err) + + err = db.Exec(`INSERT INTO sessions (token, created_at, updated_at, expires_at) + VALUES (?, ?, ?, ?)`, + "sess-token", now, now, now).Error + require.NoError(t, err) + + // Run the encryption columns migration + err = migrationAddEncryptionColumns(ctx, db) + require.NoError(t, err) + + // Verify encryption_status column exists and is backfilled on all 9 tables + encTables := []string{ + "config_keys", "governance_virtual_keys", "sessions", + "oauth_configs", "oauth_tokens", "config_mcp_clients", + "config_providers", "config_vector_store", "config_plugins", + } + for _, table := range encTables { + var count int64 + err := db.Table(table).Where("encryption_status = 'plain_text'").Count(&count).Error + if err != nil { + // Table might be empty, just verify column exists + assert.True(t, db.Migrator().HasColumn(table, "encryption_status"), + "encryption_status column should exist on %s", table) + continue + } + } + + // Verify pre-existing rows have encryption_status = 'plain_text' + var keyStatus string + err = db.Table("config_keys").Select("encryption_status").Where("key_id = ?", "ek-1").Scan(&keyStatus).Error + require.NoError(t, err) + assert.Equal(t, "plain_text", keyStatus) + + var providerStatus string + err = db.Table("config_providers").Select("encryption_status").Where("name = ?", "openai").Scan(&providerStatus).Error + require.NoError(t, err) + assert.Equal(t, "plain_text", providerStatus) + + var vkStatus string + err = db.Table("governance_virtual_keys").Select("encryption_status").Where("id = ?", "vk-enc-1").Scan(&vkStatus).Error + require.NoError(t, err) + assert.Equal(t, "plain_text", vkStatus) + + // Verify value_hash on governance_virtual_keys is NULL (not empty string) + var rawHash *string + err = db.Table("governance_virtual_keys").Select("value_hash").Where("id = ?", "vk-enc-1").Scan(&rawHash).Error + require.NoError(t, err) + assert.Nil(t, rawHash, "value_hash should be NULL, not empty string") + + // Verify token_hash on sessions is NULL + var tokenHash *string + err = db.Table("sessions").Select("token_hash").Where("token = ?", "sess-token").Scan(&tokenHash).Error + require.NoError(t, err) + assert.Nil(t, tokenHash, "token_hash should be NULL, not empty string") + + // Idempotency: running again should not error + err = migrationAddEncryptionColumns(ctx, db) + require.NoError(t, err) +} + +func TestMigrationCleanupMCPClientToolsConfig(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + ctx := context.Background() + + // Create the MCP clients table using AutoMigrate + err = db.AutoMigrate(&tables.TableMCPClient{}) + require.NoError(t, err) + + // Create the migrations tracking table + err = db.Exec(`CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)`).Error + require.NoError(t, err) + + now := time.Now() + + // Insert clients with various tools_to_execute_json states + clients := []struct { + name string + tools string + }{ + {"client_empty_array", "[]"}, + {"client_empty_string", ""}, + {"client_wildcard", `["*"]`}, + {"client_specific", `["tool1","tool2"]`}, + } + + for i, c := range clients { + err := db.Exec(`INSERT INTO config_mcp_clients (client_id, name, connection_type, tools_to_execute_json, created_at, updated_at, encryption_status) + VALUES (?, ?, 'stdio', ?, ?, ?, 'plain_text')`, + fmt.Sprintf("client-%d", i), c.name, c.tools, now, now).Error + require.NoError(t, err) + } + + // Also insert one with NULL tools + err = db.Exec(`INSERT INTO config_mcp_clients (client_id, name, connection_type, tools_to_execute_json, created_at, updated_at, encryption_status) + VALUES (?, ?, 'stdio', NULL, ?, ?, 'plain_text')`, + "client-null", "client_null_tools", now, now).Error + require.NoError(t, err) + + // Run the cleanup migration + err = migrationCleanupMCPClientToolsConfig(ctx, db) + require.NoError(t, err) + + // Verify: empty/null → ["*"], existing values preserved + for _, tc := range []struct { + name string + expected string + }{ + {"client_empty_array", `["*"]`}, + {"client_empty_string", `["*"]`}, + {"client_null_tools", `["*"]`}, + {"client_wildcard", `["*"]`}, + {"client_specific", `["tool1","tool2"]`}, + } { + var tools string + err := db.Table("config_mcp_clients").Select("tools_to_execute_json"). + Where("name = ?", tc.name).Scan(&tools).Error + require.NoError(t, err) + assert.Equal(t, tc.expected, tools, "tools for %s", tc.name) + } +} + +func TestMigrationAddConfigHashColumn(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + ctx := context.Background() + + // Create migrations table + err = db.Exec(`CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)`).Error + require.NoError(t, err) + + // Create tables WITHOUT config_hash column to simulate pre-migration state + err = db.Exec(`CREATE TABLE config_providers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50) NOT NULL UNIQUE, + network_config_json TEXT, + concurrency_buffer_json TEXT, + proxy_config_json TEXT, + custom_provider_config_json TEXT, + send_back_raw_request BOOLEAN DEFAULT 0, + send_back_raw_response BOOLEAN DEFAULT 0, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + encryption_status VARCHAR(20) DEFAULT 'plain_text' + )`).Error + require.NoError(t, err) + + err = db.Exec(`CREATE TABLE config_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(255) NOT NULL, + key_id VARCHAR(255) NOT NULL UNIQUE, + provider_id INTEGER NOT NULL, + provider VARCHAR(50), + value TEXT NOT NULL, + models_json TEXT, + weight REAL, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + encryption_status VARCHAR(20) DEFAULT 'plain_text' + )`).Error + require.NoError(t, err) + + now := time.Now() + + // Insert provider and key without config_hash + err = db.Exec(`INSERT INTO config_providers (name, created_at, updated_at, encryption_status) + VALUES (?, ?, ?, 'plain_text')`, "openai", now, now).Error + require.NoError(t, err) + + var providerID uint + err = db.Table("config_providers").Select("id").Where("name = ?", "openai").Scan(&providerID).Error + require.NoError(t, err) + + err = db.Exec(`INSERT INTO config_keys (name, key_id, provider_id, provider, value, models_json, created_at, updated_at, encryption_status) + VALUES (?, ?, ?, 'openai', ?, '["*"]', ?, ?, 'plain_text')`, + "test-key", "ck-1", providerID, "sk-test-value", now, now).Error + require.NoError(t, err) + + // Run the migration + err = migrationAddConfigHashColumn(ctx, db) + require.NoError(t, err) + + // Verify config_hash column exists + assert.True(t, db.Migrator().HasColumn(&tables.TableProvider{}, "config_hash")) + assert.True(t, db.Migrator().HasColumn(&tables.TableKey{}, "config_hash")) + + // Verify hashes are non-empty + var providerHash string + err = db.Table("config_providers").Select("config_hash").Where("name = ?", "openai").Scan(&providerHash).Error + require.NoError(t, err) + assert.NotEmpty(t, providerHash, "provider config_hash should be backfilled") + + var keyHash string + err = db.Table("config_keys").Select("config_hash").Where("key_id = ?", "ck-1").Scan(&keyHash).Error + require.NoError(t, err) + assert.NotEmpty(t, keyHash, "key config_hash should be backfilled") + + // Verify provider hash matches expected computation + expectedProviderConfig := ProviderConfig{} + expectedHash, err := expectedProviderConfig.GenerateConfigHash("openai") + require.NoError(t, err) + assert.Equal(t, expectedHash, providerHash, "provider hash should match GenerateConfigHash output") +} + +func TestMigrationBackfillEmptyVirtualKeyConfigs(t *testing.T) { + // Use full migration DB to get all tables, then insert test data + store, db := setupFullMigrationDB(t) + _ = store + ctx := context.Background() + now := time.Now() + + // Clear the migration tracking for the specific migration so it runs again + db.Exec(`DELETE FROM migrations WHERE id = 'backfill_empty_virtual_key_configs'`) + + // Create a provider + err := db.Exec(`INSERT INTO config_providers (name, encryption_status, created_at, updated_at) + VALUES ('openai', 'plain_text', ?, ?)`, now, now).Error + require.NoError(t, err) + + // Create a virtual key with NO provider configs + err = db.Exec(`INSERT INTO governance_virtual_keys (id, name, value, is_active, encryption_status, created_at, updated_at) + VALUES ('vk-backfill-1', 'backfill-vk', 'vk-value', true, 'plain_text', ?, ?)`, now, now).Error + require.NoError(t, err) + + // Verify no provider configs exist + var provConfigCount int64 + db.Table("governance_virtual_key_provider_configs").Where("virtual_key_id = ?", "vk-backfill-1").Count(&provConfigCount) + assert.Equal(t, int64(0), provConfigCount, "should have no provider configs before migration") + + // Run the migration + err = migrationBackfillEmptyVirtualKeyConfigs(ctx, db) + require.NoError(t, err) + + // Verify provider configs were created for the VK + db.Table("governance_virtual_key_provider_configs").Where("virtual_key_id = ?", "vk-backfill-1").Count(&provConfigCount) + assert.Equal(t, int64(1), provConfigCount, "should have 1 provider config after backfill (one for openai)") + + // Verify the provider config has correct defaults + var allowAllKeys bool + err = db.Table("governance_virtual_key_provider_configs"). + Select("allow_all_keys").Where("virtual_key_id = ?", "vk-backfill-1").Scan(&allowAllKeys).Error + require.NoError(t, err) + assert.True(t, allowAllKeys, "backfilled provider config should have AllowAllKeys=true") + + // Verify config_hash was recomputed (non-empty) + var vkHash string + err = db.Table("governance_virtual_keys").Select("config_hash"). + Where("id = ?", "vk-backfill-1").Scan(&vkHash).Error + require.NoError(t, err) + assert.NotEmpty(t, vkHash, "VK config_hash should be recomputed after backfill") +} + +func TestMigrationBackfillAllowedModelsWildcard(t *testing.T) { + _, db := setupFullMigrationDB(t) + ctx := context.Background() + now := time.Now() + + // Clear migration tracking so it runs again + db.Exec(`DELETE FROM migrations WHERE id = 'backfill_allowed_models_wildcard'`) + + // Create a provider + err := db.Exec(`INSERT INTO config_providers (name, encryption_status, created_at, updated_at) + VALUES ('openai', 'plain_text', ?, ?)`, now, now).Error + require.NoError(t, err) + + var providerID uint + err = db.Table("config_providers").Select("id").Where("name = ?", "openai").Scan(&providerID).Error + require.NoError(t, err) + + // Create a key with empty models_json + err = db.Exec(`INSERT INTO config_keys (name, key_id, provider_id, provider, value, models_json, + encryption_status, created_at, updated_at) + VALUES ('empty-models-key', 'emk-1', ?, 'openai', 'sk-test', '[]', 'plain_text', ?, ?)`, + providerID, now, now).Error + require.NoError(t, err) + + // Create a VK + err = db.Exec(`INSERT INTO governance_virtual_keys (id, name, value, is_active, encryption_status, created_at, updated_at) + VALUES ('vk-wildcard-1', 'wildcard-vk', 'vk-val', true, 'plain_text', ?, ?)`, now, now).Error + require.NoError(t, err) + + // Create a provider config with empty allowed_models + err = db.Exec(`INSERT INTO governance_virtual_key_provider_configs (virtual_key_id, provider, allowed_models, allow_all_keys) + VALUES ('vk-wildcard-1', 'openai', '[]', true)`).Error + require.NoError(t, err) + + // Run the migration + err = migrationBackfillAllowedModelsWildcard(ctx, db) + require.NoError(t, err) + + // Verify provider config allowed_models changed to ["*"] + var allowedModels string + err = db.Table("governance_virtual_key_provider_configs"). + Select("allowed_models").Where("virtual_key_id = ?", "vk-wildcard-1").Scan(&allowedModels).Error + require.NoError(t, err) + assert.Equal(t, `["*"]`, allowedModels, "allowed_models should be backfilled to wildcard") + + // Verify key models_json changed to ["*"] + var modelsJSON string + err = db.Table("config_keys").Select("models_json").Where("key_id = ?", "emk-1").Scan(&modelsJSON).Error + require.NoError(t, err) + assert.Equal(t, `["*"]`, modelsJSON, "models_json should be backfilled to wildcard") + + // Verify key config_hash was recomputed + var keyHash string + err = db.Table("config_keys").Select("config_hash").Where("key_id = ?", "emk-1").Scan(&keyHash).Error + require.NoError(t, err) + assert.NotEmpty(t, keyHash, "key config_hash should be recomputed") +} + +func TestMigrationRemoveServerPrefixFromMCPTools(t *testing.T) { + _, db := setupFullMigrationDB(t) + ctx := context.Background() + now := time.Now() + + // Clear migration tracking so it runs again + db.Exec(`DELETE FROM migrations WHERE id = 'remove_server_prefix_from_mcp_tools'`) + + // Insert an MCP client with prefixed tool names + toolsJSON, _ := json.Marshal([]string{"my_server_tool1", "my_server_tool2", "standalone_tool"}) + autoToolsJSON, _ := json.Marshal([]string{"my_server_auto1"}) + + err := db.Exec(`INSERT INTO config_mcp_clients (client_id, name, connection_type, + tools_to_execute_json, tools_to_auto_execute_json, encryption_status, created_at, updated_at) + VALUES (?, ?, 'stdio', ?, ?, 'plain_text', ?, ?)`, + "mcp-prefix-1", "my_server", string(toolsJSON), string(autoToolsJSON), now, now).Error + require.NoError(t, err) + + // Run the migration + err = migrationRemoveServerPrefixFromMCPTools(ctx, db) + require.NoError(t, err) + + // Verify tools had prefixes stripped + var resultToolsJSON string + err = db.Table("config_mcp_clients").Select("tools_to_execute_json"). + Where("client_id = ?", "mcp-prefix-1").Scan(&resultToolsJSON).Error + require.NoError(t, err) + + var resultTools []string + err = json.Unmarshal([]byte(resultToolsJSON), &resultTools) + require.NoError(t, err) + + assert.Contains(t, resultTools, "tool1", "should have stripped my_server_ prefix") + assert.Contains(t, resultTools, "tool2", "should have stripped my_server_ prefix") + assert.Contains(t, resultTools, "standalone_tool", "should preserve non-prefixed tool") + assert.NotContains(t, resultTools, "my_server_tool1", "original prefixed name should be gone") + + // Verify auto-execute tools + var resultAutoToolsJSON string + err = db.Table("config_mcp_clients").Select("tools_to_auto_execute_json"). + Where("client_id = ?", "mcp-prefix-1").Scan(&resultAutoToolsJSON).Error + require.NoError(t, err) + + var resultAutoTools []string + err = json.Unmarshal([]byte(resultAutoToolsJSON), &resultAutoTools) + require.NoError(t, err) + assert.Contains(t, resultAutoTools, "auto1", "should have stripped my_server_ prefix from auto tools") +} + +func TestMigrationRemoveServerPrefixFromMCPTools_Collision(t *testing.T) { + _, db := setupFullMigrationDB(t) + ctx := context.Background() + now := time.Now() + + // Clear migration tracking + db.Exec(`DELETE FROM migrations WHERE id = 'remove_server_prefix_from_mcp_tools'`) + + // Client where stripping the prefix would cause a collision: + // "srv_read" (prefixed) → "read", but "read" already exists in the list + toolsJSON, _ := json.Marshal([]string{"srv_read", "read"}) + err := db.Exec(`INSERT INTO config_mcp_clients (client_id, name, connection_type, + tools_to_execute_json, encryption_status, created_at, updated_at) + VALUES (?, ?, 'stdio', ?, 'plain_text', ?, ?)`, + "mcp-collision", "srv", string(toolsJSON), now, now).Error + require.NoError(t, err) + + // Run the migration — should not error, collision is handled + err = migrationRemoveServerPrefixFromMCPTools(ctx, db) + require.NoError(t, err) + + // The collision drops the duplicate, keeping "read" once + var resultJSON string + err = db.Table("config_mcp_clients").Select("tools_to_execute_json"). + Where("client_id = ?", "mcp-collision").Scan(&resultJSON).Error + require.NoError(t, err) + + var resultTools []string + err = json.Unmarshal([]byte(resultJSON), &resultTools) + require.NoError(t, err) + assert.Contains(t, resultTools, "read") + // Should not have duplicates + readCount := 0 + for _, tool := range resultTools { + if tool == "read" { + readCount++ + } + } + assert.Equal(t, 1, readCount, "should deduplicate on collision") +} + +func TestMigrationReplaceEnableLiteLLMWithCompatColumns(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + ctx := context.Background() + + // Create migrations table + err = db.Exec(`CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)`).Error + require.NoError(t, err) + + // Create config_client from the GORM model so DropColumn works (SQLite recreates + // the table from model columns). Then add the legacy column on top. + err = db.AutoMigrate(&tables.TableClientConfig{}) + require.NoError(t, err) + + // Add the legacy column that the migration will read and drop + err = db.Exec(`ALTER TABLE config_client ADD COLUMN enable_litellm_fallbacks BOOLEAN DEFAULT 0`).Error + require.NoError(t, err) + + now := time.Now() + + // Insert a row with enable_litellm_fallbacks = true + err = db.Exec(`INSERT INTO config_client (enable_litellm_fallbacks, created_at, updated_at) + VALUES (1, ?, ?)`, now, now).Error + require.NoError(t, err) + + // Insert a row with enable_litellm_fallbacks = false + err = db.Exec(`INSERT INTO config_client (enable_litellm_fallbacks, created_at, updated_at) + VALUES (0, ?, ?)`, now, now).Error + require.NoError(t, err) + + // Run the migration + err = migrationReplaceEnableLiteLLMWithCompatColumns(ctx, db) + require.NoError(t, err) + + // Verify new columns exist + mgr := db.Migrator() + assert.True(t, mgr.HasColumn(&tables.TableClientConfig{}, "compat_convert_text_to_chat")) + assert.True(t, mgr.HasColumn(&tables.TableClientConfig{}, "compat_convert_chat_to_responses")) + assert.True(t, mgr.HasColumn(&tables.TableClientConfig{}, "compat_should_drop_params")) + assert.True(t, mgr.HasColumn(&tables.TableClientConfig{}, "compat_should_convert_params")) + + // Verify data migration: row 1 had litellm=true → compat_convert_text_to_chat=true + type compatRow struct { + ID uint + CompatConvertTextToChat bool `gorm:"column:compat_convert_text_to_chat"` + CompatShouldConvertParams bool `gorm:"column:compat_should_convert_params"` + } + var rows []compatRow + err = db.Table("config_client"). + Select("id, compat_convert_text_to_chat, compat_should_convert_params"). + Order("id").Find(&rows).Error + require.NoError(t, err) + require.Len(t, rows, 2) + + assert.True(t, rows[0].CompatConvertTextToChat, "row with litellm=true should have compat_convert_text_to_chat=true") + assert.False(t, rows[0].CompatShouldConvertParams, "compat_should_convert_params should default to false") + + assert.False(t, rows[1].CompatConvertTextToChat, "row with litellm=false should have compat_convert_text_to_chat=false") + assert.False(t, rows[1].CompatShouldConvertParams, "compat_should_convert_params should default to false") +} + diff --git a/framework/configstore/rdb_test.go b/framework/configstore/rdb_test.go index 406e7a5cfd..4877dd02fc 100644 --- a/framework/configstore/rdb_test.go +++ b/framework/configstore/rdb_test.go @@ -42,6 +42,10 @@ func setupRDBTestStore(t *testing.T) *RDBConfigStore { &tables.TablePromptVersionMessage{}, &tables.TablePromptSession{}, &tables.TablePromptSessionMessage{}, + &tables.TablePerUserOAuthPendingFlow{}, + &tables.TablePerUserOAuthSession{}, + &tables.TableOauthUserSession{}, + &tables.TableOauthUserToken{}, ) require.NoError(t, err, "Failed to migrate test database") diff --git a/framework/logstore/asyncjob.go b/framework/logstore/asyncjob.go index 38173c6840..a28ba33a58 100644 --- a/framework/logstore/asyncjob.go +++ b/framework/logstore/asyncjob.go @@ -80,7 +80,9 @@ func (e *AsyncJobExecutor) RetrieveJob(ctx context.Context, jobID string, vkValu } // SubmitJob creates a pending job, starts background execution, and returns the job record. -func (e *AsyncJobExecutor) SubmitJob(virtualKeyValue *string, resultTTL int, operation AsyncOperation, operationType schemas.RequestType) (*AsyncJob, error) { +// contextValues carries the original request's BifrostContext user values (virtual key, tracing +// headers, etc.) so they can be restored on the background execution context. +func (e *AsyncJobExecutor) SubmitJob(virtualKeyValue *string, resultTTL int, operation AsyncOperation, operationType schemas.RequestType, contextValues map[any]any) (*AsyncJob, error) { if resultTTL <= 0 { resultTTL = DefaultAsyncJobResultTTL } @@ -109,15 +111,20 @@ func (e *AsyncJobExecutor) SubmitJob(virtualKeyValue *string, resultTTL int, ope return nil, fmt.Errorf("failed to create async job: %w", err) } - go e.executeJob(job.ID, job.ResultTTL, operation) + go e.executeJob(job.ID, job.ResultTTL, operation, contextValues) return job, nil } // executeJob runs the operation in the background and updates the job record. -func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation) { +func (e *AsyncJobExecutor) executeJob(jobID string, resultTTL int, operation AsyncOperation, contextValues map[any]any) { ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + // Restore original request context values (virtual key, tracing headers, etc.) + for k, v := range contextValues { + ctx.SetValue(k, v) + } + markFailed := func(msg string) { now := time.Now().UTC() expiresAt := now.Add(time.Duration(resultTTL) * time.Second) diff --git a/framework/logstore/asyncjob_test.go b/framework/logstore/asyncjob_test.go new file mode 100644 index 0000000000..3008a36c63 --- /dev/null +++ b/framework/logstore/asyncjob_test.go @@ -0,0 +1,220 @@ +package logstore + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +type asyncTestLogger struct{} + +func (asyncTestLogger) Debug(string, ...any) {} +func (asyncTestLogger) Info(string, ...any) {} +func (asyncTestLogger) Warn(string, ...any) {} +func (asyncTestLogger) Error(string, ...any) {} +func (asyncTestLogger) Fatal(string, ...any) {} +func (asyncTestLogger) SetLevel(schemas.LogLevel) {} +func (asyncTestLogger) SetOutputType(schemas.LoggerOutputType) {} +func (asyncTestLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder { + return schemas.NoopLogEvent +} + +type testGovernanceStore struct { + virtualKeys map[string]*configstoreTables.TableVirtualKey +} + +func (t *testGovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) { + vk, ok := t.virtualKeys[vkValue] + return vk, ok +} + +func newTestAsyncExecutor(t *testing.T) *AsyncJobExecutor { + t.Helper() + ctx := context.Background() + + store, err := newSqliteLogStore(ctx, &SQLiteConfig{Path: ":memory:"}, asyncTestLogger{}) + require.NoError(t, err) + t.Cleanup(func() { store.Close(ctx) }) + + govStore := &testGovernanceStore{ + virtualKeys: map[string]*configstoreTables.TableVirtualKey{ + "sk-bf-test": {ID: "vk-123", Value: "sk-bf-test"}, + }, + } + + return NewAsyncJobExecutor(store, govStore, asyncTestLogger{}) +} + +// waitForJobCompletion polls until the operation callback has been invoked. +func waitForJobCompletion(t *testing.T, done *atomic.Bool) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if done.Load() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for async job execution") +} + +// waitForJobStatus polls FindAsyncJobByID until the job reaches a non-pending +// status (or times out). This avoids a fragile time.Sleep between the operation +// callback completing and the DB update finishing. +func waitForJobStatus(t *testing.T, store LogStore, jobID string) *AsyncJob { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + job, err := store.FindAsyncJobByID(context.Background(), jobID) + if err == nil && job.Status != schemas.AsyncJobStatusPending { + return job + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for async job status update") + return nil +} + +func TestSubmitJob_PropagatesContextValues(t *testing.T) { + executor := newTestAsyncExecutor(t) + + // Simulate original request context values + contextValues := map[any]any{ + schemas.BifrostContextKeyVirtualKey: "sk-bf-test", + schemas.BifrostContextKey("x-bf-prom-env"): "production", + schemas.BifrostContextKey("x-bf-eh-custom"): "custom-value", + } + + var capturedCtx *schemas.BifrostContext + var done atomic.Bool + + operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { + capturedCtx = bgCtx + done.Store(true) + return map[string]string{"status": "ok"}, nil + } + + job, err := executor.SubmitJob(strPtr("sk-bf-test"), 3600, operation, schemas.ChatCompletionRequest, contextValues) + require.NoError(t, err) + require.NotNil(t, job) + + waitForJobCompletion(t, &done) + + assert.Equal(t, "sk-bf-test", capturedCtx.Value(schemas.BifrostContextKeyVirtualKey)) + assert.Equal(t, "production", capturedCtx.Value(schemas.BifrostContextKey("x-bf-prom-env"))) + assert.Equal(t, "custom-value", capturedCtx.Value(schemas.BifrostContextKey("x-bf-eh-custom"))) + assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest)) +} + +func TestSubmitJob_NilContextValues(t *testing.T) { + executor := newTestAsyncExecutor(t) + + var capturedCtx *schemas.BifrostContext + var done atomic.Bool + + operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { + capturedCtx = bgCtx + done.Store(true) + return map[string]string{"status": "ok"}, nil + } + + job, err := executor.SubmitJob(strPtr("sk-bf-test"), 3600, operation, schemas.ChatCompletionRequest, nil) + require.NoError(t, err) + require.NotNil(t, job) + + waitForJobCompletion(t, &done) + + assert.NotNil(t, capturedCtx) + assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest)) +} + +func TestSubmitJob_EmptyContextValues(t *testing.T) { + executor := newTestAsyncExecutor(t) + + var capturedCtx *schemas.BifrostContext + var done atomic.Bool + + operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { + capturedCtx = bgCtx + done.Store(true) + return map[string]string{"status": "ok"}, nil + } + + job, err := executor.SubmitJob(strPtr("sk-bf-test"), 3600, operation, schemas.ChatCompletionRequest, map[any]any{}) + require.NoError(t, err) + require.NotNil(t, job) + + waitForJobCompletion(t, &done) + + assert.NotNil(t, capturedCtx) + assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest)) +} + +func TestSubmitJob_AsyncFlagOverridesContextValues(t *testing.T) { + executor := newTestAsyncExecutor(t) + + // Pass context values that try to set BifrostIsAsyncRequest to false + contextValues := map[any]any{ + schemas.BifrostIsAsyncRequest: false, + } + + var capturedCtx *schemas.BifrostContext + var done atomic.Bool + + operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { + capturedCtx = bgCtx + done.Store(true) + return map[string]string{"status": "ok"}, nil + } + + job, err := executor.SubmitJob(strPtr("sk-bf-test"), 3600, operation, schemas.ChatCompletionRequest, contextValues) + require.NoError(t, err) + require.NotNil(t, job) + + waitForJobCompletion(t, &done) + + // BifrostIsAsyncRequest must be true — set AFTER restoring context values + assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest)) +} + +func TestSubmitJob_OperationFailure_PreservesContext(t *testing.T) { + executor := newTestAsyncExecutor(t) + + contextValues := map[any]any{ + schemas.BifrostContextKeyVirtualKey: "sk-bf-test", + } + + var capturedCtx *schemas.BifrostContext + var done atomic.Bool + + statusCode := fasthttp.StatusBadRequest + operation := func(bgCtx *schemas.BifrostContext) (interface{}, *schemas.BifrostError) { + capturedCtx = bgCtx + done.Store(true) + return nil, &schemas.BifrostError{ + StatusCode: &statusCode, + Error: &schemas.ErrorField{Message: "test error"}, + } + } + + job, err := executor.SubmitJob(strPtr("sk-bf-test"), 3600, operation, schemas.ChatCompletionRequest, contextValues) + require.NoError(t, err) + require.NotNil(t, job) + + waitForJobCompletion(t, &done) + + // Context values should still be available even when operation fails + assert.Equal(t, "sk-bf-test", capturedCtx.Value(schemas.BifrostContextKeyVirtualKey)) + assert.Equal(t, true, capturedCtx.Value(schemas.BifrostIsAsyncRequest)) + + // Verify job was marked as failed — poll until DB update completes + retrievedJob := waitForJobStatus(t, executor.logstore, job.ID) + assert.Equal(t, schemas.AsyncJobStatusFailed, retrievedJob.Status) +} diff --git a/framework/vectorstore/redis_test.go b/framework/vectorstore/redis_test.go index 335dfb4df3..fb742b71fb 100644 --- a/framework/vectorstore/redis_test.go +++ b/framework/vectorstore/redis_test.go @@ -3,15 +3,8 @@ package vectorstore import ( "bufio" "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" "fmt" "io" - "math/big" "net" "os" "strconv" @@ -230,29 +223,12 @@ func TestRedisConfig_Validation(t *testing.T) { }, expectError: false, }, - { - name: "cluster mode with db 0", - config: RedisConfig{ - Addr: schemas.NewEnvVar("localhost:6379"), - ClusterMode: schemas.NewEnvVar("true"), - }, - expectError: false, - }, - { - name: "cluster mode rejects non-zero db", - config: RedisConfig{ - Addr: schemas.NewEnvVar("localhost:6379"), - DB: schemas.NewEnvVar("1"), - ClusterMode: schemas.NewEnvVar("true"), - }, - expectError: true, - errorMsg: "redis cluster mode does not support database selection", - }, { name: "tls enabled", config: RedisConfig{ - Addr: schemas.NewEnvVar("localhost:6380"), - UseTLS: schemas.NewEnvVar("true"), + Addr: schemas.NewEnvVar("localhost:6380"), + UseTLS: schemas.NewEnvVar("true"), + InsecureSkipVerify: schemas.NewEnvVar("true"), }, expectError: false, }, @@ -278,33 +254,23 @@ func TestRedisConfig_Validation(t *testing.T) { } } -func validTestCertPEM(t *testing.T) string { +// readTestCACert loads the CA cert generated by the tests/docker-compose.yml +// redis-certs-init service. Tests requiring CA-verified TLS skip if the file +// isn't present (e.g. when docker compose hasn't been brought up). +func readTestCACert(t *testing.T) string { t.Helper() - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + data, err := os.ReadFile("../../tests/redis-certs/ca.crt") if err != nil { - t.Fatalf("generate key: %v", err) - } - template := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "test-ca"}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageCertSign, - BasicConstraintsValid: true, - IsCA: true, + t.Skipf("redis test CA cert not available; bring up tests/docker-compose.yml first: %v", err) } - certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) - if err != nil { - t.Fatalf("create certificate: %v", err) - } - return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + return string(data) } func TestNewRedisStore_ConfiguresStandaloneTLSClient(t *testing.T) { logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) store, err := newRedisStore(context.Background(), RedisConfig{ - Addr: schemas.NewEnvVar("localhost:6379"), + Addr: schemas.NewEnvVar("localhost:6380"), DB: schemas.NewEnvVar("2"), UseTLS: schemas.NewEnvVar("true"), InsecureSkipVerify: schemas.NewEnvVar("true"), @@ -322,10 +288,10 @@ func TestNewRedisStore_ConfiguresStandaloneTLSClientWithCACert(t *testing.T) { logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) store, err := newRedisStore(context.Background(), RedisConfig{ - Addr: schemas.NewEnvVar("localhost:6379"), + Addr: schemas.NewEnvVar("localhost:6380"), DB: schemas.NewEnvVar("2"), UseTLS: schemas.NewEnvVar("true"), - CACertPEM: schemas.NewEnvVar(validTestCertPEM(t)), + CACertPEM: schemas.NewEnvVar(readTestCACert(t)), }, logger) require.NoError(t, err) @@ -340,7 +306,7 @@ func TestNewRedisStore_ConfiguresClusterTLSClient(t *testing.T) { logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) store, err := newRedisStore(context.Background(), RedisConfig{ - Addr: schemas.NewEnvVar("localhost:6379"), + Addr: schemas.NewEnvVar("localhost:7100"), UseTLS: schemas.NewEnvVar("true"), InsecureSkipVerify: schemas.NewEnvVar("true"), ClusterMode: schemas.NewEnvVar("true"), @@ -349,7 +315,7 @@ func TestNewRedisStore_ConfiguresClusterTLSClient(t *testing.T) { client, ok := store.client.(*redis.ClusterClient) require.True(t, ok, "expected redis cluster client") - require.Equal(t, []string{"localhost:6379"}, client.Options().Addrs) + require.Equal(t, []string{"localhost:7100"}, client.Options().Addrs) require.NotNil(t, client.Options().TLSConfig) assert.True(t, client.Options().TLSConfig.InsecureSkipVerify) } @@ -358,9 +324,9 @@ func TestNewRedisStore_ConfiguresClusterTLSClientWithCACert(t *testing.T) { logger := bifrost.NewDefaultLogger(schemas.LogLevelInfo) store, err := newRedisStore(context.Background(), RedisConfig{ - Addr: schemas.NewEnvVar("localhost:6379"), + Addr: schemas.NewEnvVar("localhost:7100"), UseTLS: schemas.NewEnvVar("true"), - CACertPEM: schemas.NewEnvVar(validTestCertPEM(t)), + CACertPEM: schemas.NewEnvVar(readTestCACert(t)), ClusterMode: schemas.NewEnvVar("true"), }, logger) require.NoError(t, err) diff --git a/framework/version b/framework/version index 3a3cd8cc8b..1892b92676 100644 --- a/framework/version +++ b/framework/version @@ -1 +1 @@ -1.3.1 +1.3.2 diff --git a/plugins/compat/changelog.md b/plugins/compat/changelog.md index ad3d633b71..00bd8899e7 100644 --- a/plugins/compat/changelog.md +++ b/plugins/compat/changelog.md @@ -1,2 +1 @@ -- feat: Adds option for converting chat completions to responses for models that support it -- feat: Adds option for dropping unsupported model parameters \ No newline at end of file +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/compat/version b/plugins/compat/version index 6c6aa7cb09..17e51c385e 100644 --- a/plugins/compat/version +++ b/plugins/compat/version @@ -1 +1 @@ -0.1.0 \ No newline at end of file +0.1.1 diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/governance/version b/plugins/governance/version index 26ca594609..4cda8f19ed 100644 --- a/plugins/governance/version +++ b/plugins/governance/version @@ -1 +1 @@ -1.5.1 +1.5.2 diff --git a/plugins/jsonparser/changelog.md b/plugins/jsonparser/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/jsonparser/changelog.md +++ b/plugins/jsonparser/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/jsonparser/version b/plugins/jsonparser/version index 26ca594609..4cda8f19ed 100644 --- a/plugins/jsonparser/version +++ b/plugins/jsonparser/version @@ -1 +1 @@ -1.5.1 +1.5.2 diff --git a/plugins/litellmcompat/changelog.md b/plugins/litellmcompat/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/litellmcompat/changelog.md +++ b/plugins/litellmcompat/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/litellmcompat/version b/plugins/litellmcompat/version index 17e51c385e..d917d3e26a 100644 --- a/plugins/litellmcompat/version +++ b/plugins/litellmcompat/version @@ -1 +1 @@ -0.1.1 +0.1.2 diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/logging/version b/plugins/logging/version index 26ca594609..4cda8f19ed 100644 --- a/plugins/logging/version +++ b/plugins/logging/version @@ -1 +1 @@ -1.5.1 +1.5.2 diff --git a/plugins/maxim/changelog.md b/plugins/maxim/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/maxim/changelog.md +++ b/plugins/maxim/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/maxim/version b/plugins/maxim/version index 9c6d6293b1..fdd3be6df5 100644 --- a/plugins/maxim/version +++ b/plugins/maxim/version @@ -1 +1 @@ -1.6.1 +1.6.2 diff --git a/plugins/mocker/changelog.md b/plugins/mocker/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/mocker/changelog.md +++ b/plugins/mocker/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/mocker/version b/plugins/mocker/version index 26ca594609..4cda8f19ed 100644 --- a/plugins/mocker/version +++ b/plugins/mocker/version @@ -1 +1 @@ -1.5.1 +1.5.2 diff --git a/plugins/otel/changelog.md b/plugins/otel/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/otel/changelog.md +++ b/plugins/otel/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/otel/version b/plugins/otel/version index 6085e94650..23aa839063 100644 --- a/plugins/otel/version +++ b/plugins/otel/version @@ -1 +1 @@ -1.2.1 +1.2.2 diff --git a/plugins/prompts/changelog.md b/plugins/prompts/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/prompts/changelog.md +++ b/plugins/prompts/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/prompts/version b/plugins/prompts/version index 7f207341d5..6d7de6e6ab 100644 --- a/plugins/prompts/version +++ b/plugins/prompts/version @@ -1 +1 @@ -1.0.1 \ No newline at end of file +1.0.2 diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/semanticcache/changelog.md +++ b/plugins/semanticcache/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/semanticcache/version b/plugins/semanticcache/version index 26ca594609..4cda8f19ed 100644 --- a/plugins/semanticcache/version +++ b/plugins/semanticcache/version @@ -1 +1 @@ -1.5.1 +1.5.2 diff --git a/plugins/telemetry/changelog.md b/plugins/telemetry/changelog.md index e69de29bb2..00bd8899e7 100644 --- a/plugins/telemetry/changelog.md +++ b/plugins/telemetry/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.2 and framework to v1.3.2 diff --git a/plugins/telemetry/version b/plugins/telemetry/version index 26ca594609..4cda8f19ed 100644 --- a/plugins/telemetry/version +++ b/plugins/telemetry/version @@ -1 +1 @@ -1.5.1 +1.5.2 diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index d86059bcab..e1c993062b 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -46,6 +46,126 @@ services: timeout: 10s retries: 3 + # Generates self-signed TLS certs shared with the TLS Redis services. + # Writes to ./redis-certs on the host so tests can read the CA cert. + redis-certs-init: + image: alpine:3.19 + volumes: + - ./redis-certs:/tls + command: > + sh -c " + set -e; + if [ ! -f /tls/redis.crt ]; then + apk add --no-cache openssl >/dev/null; + cd /tls; + openssl req -x509 -newkey rsa:2048 -days 365 -nodes -keyout ca.key -out ca.crt -subj '/CN=bifrost-test-ca'; + openssl req -new -newkey rsa:2048 -nodes -keyout redis.key -out redis.csr -subj '/CN=localhost' -addext 'subjectAltName=DNS:localhost,IP:127.0.0.1'; + openssl x509 -req -in redis.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out redis.crt -days 365 -extfile <(printf 'subjectAltName=DNS:localhost,IP:127.0.0.1'); + chmod 644 ca.crt ca.key redis.crt redis.key; + fi; + echo 'certs ready'; + " + networks: + - bifrost_network + + # TLS-enabled standalone Redis on 6380 for TLS client tests. + redis-tls: + image: redis:7.4-alpine + depends_on: + redis-certs-init: + condition: service_completed_successfully + volumes: + - ./redis-certs:/tls:ro + command: > + redis-server + --tls-port 6380 + --port 0 + --tls-cert-file /tls/redis.crt + --tls-key-file /tls/redis.key + --tls-ca-cert-file /tls/ca.crt + --tls-auth-clients no + --protected-mode no + ports: + - "6380:6380" + networks: + bifrost_network: + ipv4_address: 172.28.0.17 + + # Single-node Redis Cluster on 7000 for cluster client tests. + redis-cluster: + image: redis:7.4-alpine + entrypoint: ["sh", "-c"] + command: + - | + redis-server \ + --port 7000 \ + --bind 0.0.0.0 \ + --cluster-enabled yes \ + --cluster-config-file /tmp/nodes.conf \ + --cluster-announce-ip 127.0.0.1 \ + --cluster-announce-port 7000 \ + --cluster-announce-bus-port 17000 \ + --protected-mode no & + SERVER_PID=$$! + until redis-cli -p 7000 ping >/dev/null 2>&1; do sleep 0.2; done + redis-cli -p 7000 cluster addslotsrange 0 16383 || true + wait $$SERVER_PID + ports: + - "7000:7000" + - "17000:17000" + healthcheck: + test: ["CMD-SHELL", "redis-cli -p 7000 cluster info | grep -q cluster_state:ok"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 5s + networks: + bifrost_network: + ipv4_address: 172.28.0.18 + + # Single-node TLS Redis Cluster on 7100 for TLS cluster client tests. + redis-cluster-tls: + image: redis:7.4-alpine + depends_on: + redis-certs-init: + condition: service_completed_successfully + volumes: + - ./redis-certs:/tls:ro + entrypoint: ["sh", "-c"] + command: + - | + redis-server \ + --tls-port 7100 \ + --port 0 \ + --bind 0.0.0.0 \ + --tls-cert-file /tls/redis.crt \ + --tls-key-file /tls/redis.key \ + --tls-ca-cert-file /tls/ca.crt \ + --tls-auth-clients no \ + --tls-cluster yes \ + --cluster-enabled yes \ + --cluster-config-file /tmp/nodes.conf \ + --cluster-announce-ip 127.0.0.1 \ + --cluster-announce-tls-port 7100 \ + --cluster-announce-bus-port 17100 \ + --protected-mode no & + SERVER_PID=$$! + until redis-cli --tls --cacert /tls/ca.crt -p 7100 ping >/dev/null 2>&1; do sleep 0.2; done + redis-cli --tls --cacert /tls/ca.crt -p 7100 cluster addslotsrange 0 16383 || true + wait $$SERVER_PID + ports: + - "7100:7100" + - "17100:17100" + healthcheck: + test: ["CMD-SHELL", "redis-cli --tls --cacert /tls/ca.crt -p 7100 cluster info | grep -q cluster_state:ok"] + interval: 5s + timeout: 3s + retries: 10 + start_period: 5s + networks: + bifrost_network: + ipv4_address: 172.28.0.19 + # Qdrant instance for vector store tests qdrant: image: qdrant/qdrant:v1.16.0 diff --git a/transports/bifrost-http/handlers/asyncinference.go b/transports/bifrost-http/handlers/asyncinference.go index da2d27594c..06247d459c 100644 --- a/transports/bifrost-http/handlers/asyncinference.go +++ b/transports/bifrost-http/handlers/asyncinference.go @@ -128,6 +128,7 @@ func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) { return h.client.TextCompletionRequest(bgCtx, bifrostTextReq) }, schemas.TextCompletionRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, err.Error()) @@ -166,6 +167,7 @@ func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) { return h.client.ChatCompletionRequest(bgCtx, bifrostChatReq) }, schemas.ChatCompletionRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) @@ -204,6 +206,7 @@ func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) { return h.client.ResponsesRequest(bgCtx, bifrostResponsesReq) }, schemas.ResponsesRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Failed to create async job: %v", err)) @@ -238,6 +241,7 @@ func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) { return h.client.EmbeddingRequest(bgCtx, bifrostEmbeddingReq) }, schemas.EmbeddingRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) @@ -276,6 +280,7 @@ func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) { return h.client.SpeechRequest(bgCtx, bifrostSpeechReq) }, schemas.SpeechRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) @@ -314,6 +319,7 @@ func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) { return h.client.TranscriptionRequest(bgCtx, bifrostTranscriptionReq) }, schemas.TranscriptionRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) @@ -352,6 +358,7 @@ func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) { return h.client.ImageGenerationRequest(bgCtx, bifrostReq) }, schemas.ImageGenerationRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) @@ -390,6 +397,7 @@ func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) { return h.client.ImageEditRequest(bgCtx, bifrostReq) }, schemas.ImageEditRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) @@ -423,6 +431,7 @@ func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) { return h.client.ImageVariationRequest(bgCtx, bifrostReq) }, schemas.ImageVariationRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) @@ -456,6 +465,7 @@ func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) { return h.client.RerankRequest(bgCtx, bifrostReq) }, schemas.RerankRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, err.Error()) @@ -489,6 +499,7 @@ func (h *AsyncHandler) asyncOCR(ctx *fasthttp.RequestCtx) { return h.client.OCRRequest(bgCtx, bifrostReq) }, schemas.OCRRequest, + bifrostCtx.GetUserValues(), ) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, err.Error()) diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 57d3d7a706..c5ea73e191 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -14,6 +14,8 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" + "time" "github.com/bytedance/sonic" "github.com/fasthttp/router" @@ -1667,8 +1669,17 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, bi // The streaming callback will complete the trace after the stream ends ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) - // Get the trace completer function for use in the streaming callback - traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + // Pre-allocate atomic.Value slot for the transport post-hook completer. + // TransportInterceptorMiddleware stores the completer into this slot after next(ctx) + // returns. The goroutine reads from the closure-captured pointer, avoiding any ctx + // access after the handler returns (fasthttp recycles RequestCtx). + var completerSlot atomic.Value + ctx.SetUserValue(schemas.BifrostContextKeyTransportPostHookCompleter, &completerSlot) + + // Get the trace completer function for use in the streaming callback. + // Signature: func([]schemas.PluginLogEntry) — accepts transport plugin logs so it + // never needs to read from ctx.UserValue (ctx may be recycled). + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func([]schemas.PluginLogEntry)) // Get stream chunk interceptor for plugin hooks interceptor := h.config.GetStreamChunkInterceptor() @@ -1684,25 +1695,64 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, bi // Producer goroutine: processes the stream channel, formats SSE events, sends to reader go func() { - defer func() { - schemas.ReleaseHTTPRequest(httpReq) - // Retrieve and run transport post-hook completer before closing the stream - // so errors can still be communicated to the client as SSE events. - // Must retrieve here (not before goroutine) because TransportInterceptorMiddleware - // sets BifrostContextKeyTransportPostHookCompleter after next(ctx) returns. - if postHookCompleter, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPostHookCompleter).(func() error); ok && postHookCompleter != nil { - if err := postHookCompleter(); err != nil { + var transportLogs []schemas.PluginLogEntry + completerRan := false + // runCompleter invokes the transport post-hook completer at most once. + // sendSSEOnError=true emits plugin errors as SSE "event: error" frames so the + // client sees them (happy path, before [DONE]); =false logs server-side only + // (early-return / defer fallback, after stream termination). + runCompleter := func(sendSSEOnError bool) { + if completerRan { + return + } + // Bounded wait for TransportInterceptorMiddleware to publish the completer. + // It calls slot.Store after next(ctx) returns, which races with this goroutine + // on fast/empty streams. 100ms is ample — the store runs a few instructions + // after the handler returns. + var loaded any + deadline := time.Now().Add(100 * time.Millisecond) + for { + if loaded = completerSlot.Load(); loaded != nil { + break + } + if time.Now().After(deadline) { + break + } + time.Sleep(time.Millisecond) + } + if loaded == nil { + return + } + postHookCompleter, ok := loaded.(func() ([]schemas.PluginLogEntry, error)) + if !ok { + return + } + completerRan = true + logs, err := postHookCompleter() + if err != nil { + if sendSSEOnError { errorJSON, marshalErr := sonic.Marshal(map[string]string{"error": err.Error()}) if marshalErr == nil { reader.SendError(errorJSON) } + } else { + logger.Warn("transport post-hook failed after stream terminated: %v", err) } } + transportLogs = logs + } + + defer func() { + schemas.ReleaseHTTPRequest(httpReq) + // Fallback: on early-return paths (client disconnect, interceptor error) + // we never reached the pre-[DONE] invocation, so run it now. Any error is + // logged server-side only — the stream is already closing. + runCompleter(false) reader.Done() - // Complete the trace after streaming finishes - // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL + // Complete the trace after streaming finishes, passing transport plugin logs. + // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL. if traceCompleter != nil { - traceCompleter() + traceCompleter(transportLogs) } }() @@ -1778,6 +1828,12 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, bi } } + // Run the transport post-hook completer BEFORE the terminal [DONE] marker so + // that any plugin error can still be delivered to the client as an SSE event. + // Post-hooks emitted after [DONE] reach the wire but most clients stop reading + // once they see [DONE], so they'd be silently dropped. + runCompleter(true) + if !includeEventType && !skipDoneMarker { // Send the [DONE] marker to indicate the end of the stream (only for non-responses/image-gen APIs) if !reader.SendDone() { diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 830d1135f8..442f3a172e 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -357,10 +357,43 @@ func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddl // The streaming handler calls this BEFORE reader.Done() so that errors can // still be sent as SSE events. applyResponse=false because the response is // already on the wire and mutating ctx.Response would corrupt the chunked stream. + // + // IMPORTANT: The callback must NOT access ctx — fasthttp recycles RequestCtx + // after the response body stream completes. All needed data is eagerly captured + // here (while ctx is still valid) and passed through the closure. if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { - ctx.SetUserValue(schemas.BifrostContextKeyTransportPostHookCompleter, func() error { - return runTransportPostHooks(ctx, plugins, bifrostCtx, false) - }) + // Verify the completer slot exists before allocating pooled snapshots. + // The streaming handler pre-allocates this *atomic.Value; if absent, + // skip work to avoid leaking pooled HTTPRequest/HTTPResponse objects. + slot, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPostHookCompleter).(*atomic.Value) + if !ok { + return + } + + // Eagerly snapshot request/response from ctx before it can be recycled. + capturedReq := lib.BuildHTTPRequestFromFastHTTP(ctx) + capturedResp := lib.BuildHTTPResponseFromFastHTTP(ctx) + // Snapshot pre-hook transport plugin logs already accumulated on ctx. + var preHookLogs []schemas.PluginLogEntry + if logs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok { + preHookLogs = logs + } + + completer := func() ([]schemas.PluginLogEntry, error) { + defer schemas.ReleaseHTTPRequest(capturedReq) + defer schemas.ReleaseHTTPResponse(capturedResp) + postHookLogs, err := runTransportPostHooksCaptured(capturedReq, capturedResp, plugins, bifrostCtx) + allLogs := preHookLogs + if len(postHookLogs) > 0 { + allLogs = append(allLogs, postHookLogs...) + } + return allLogs, err + } + + // Store the completer in the atomic.Value slot that the streaming handler + // placed on ctx. The goroutine reads from its closure-captured copy of + // the slot, avoiding any ctx access after the handler returns. + slot.Store(completer) return } @@ -410,7 +443,7 @@ func runTransportPostHooks(ctx *fasthttp.RequestCtx, plugins []schemas.HTTPTrans if shouldApplyShortCircuit { applyHTTPResponseToCtx(ctx, httpResp) } - return fmt.Errorf("HTTPTransportPostHook plugin %s: %w", pluginName, err) + return fmt.Errorf("transport post-hook plugin %s: %w", pluginName, err) } } // Drain post-hook plugin logs and merge with pre-hook logs @@ -427,6 +460,58 @@ func runTransportPostHooks(ctx *fasthttp.RequestCtx, plugins []schemas.HTTPTrans return nil } +// runTransportPostHooksCaptured is the goroutine-safe variant of runTransportPostHooks. +// It uses pre-captured HTTPRequest and HTTPResponse snapshots instead of reading from +// a fasthttp RequestCtx, which may have been recycled by the time this runs in a +// streaming goroutine. Returns accumulated plugin logs (instead of writing them to +// ctx.UserValue) so the caller can forward them to the trace completer. +func runTransportPostHooksCaptured(capturedReq *schemas.HTTPRequest, capturedResp *schemas.HTTPResponse, plugins []schemas.HTTPTransportPlugin, bifrostCtx *schemas.BifrostContext) ([]schemas.PluginLogEntry, error) { + // Clone into fresh pooled objects so plugins can mutate without affecting the snapshots. + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + req.Method = capturedReq.Method + req.Path = capturedReq.Path + for k, v := range capturedReq.Headers { + req.Headers[k] = v + } + for k, v := range capturedReq.Query { + req.Query[k] = v + } + for k, v := range capturedReq.PathParams { + req.PathParams[k] = v + } + + httpResp := schemas.AcquireHTTPResponse() + defer schemas.ReleaseHTTPResponse(httpResp) + httpResp.StatusCode = capturedResp.StatusCode + for k, v := range capturedResp.Headers { + httpResp.Headers[k] = v + } + + var allLogs []schemas.PluginLogEntry + + // Run http post-hooks in reverse order + for i := len(plugins) - 1; i >= 0; i-- { + plugin := plugins[i] + pluginName := plugin.GetName() + pluginCtx := bifrostCtx.WithPluginScope(&pluginName) + err := plugin.HTTPTransportPostHook(pluginCtx, req, httpResp) + pluginCtx.ReleasePluginScope() + if err != nil { + logger.Warn("error in HTTPTransportPostHook for plugin %s: %s", pluginName, err.Error()) + if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { + allLogs = append(allLogs, postHookLogs...) + } + return allLogs, fmt.Errorf("transport post-hook plugin %s: %w", pluginName, err) + } + } + // Drain post-hook plugin logs + if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { + allLogs = append(allLogs, postHookLogs...) + } + return allLogs, nil +} + // getBifrostContextFromFastHTTP gets or creates a BifrostContext from fasthttp context. func getBifrostContextFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.BifrostContext { return schemas.NewBifrostContext(ctx, schemas.NoDeadline) @@ -936,10 +1021,11 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { ctx.SetUserValue(schemas.BifrostContextKeyParentSpanID, parentSpanID) } - // Store a trace completion callback for streaming handlers to use - ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func() { - // Attach transport plugin logs before completing the trace (streaming path) - if transportLogs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok && len(transportLogs) > 0 { + // Store a trace completion callback for streaming handlers to use. + // Accepts transport plugin logs as a parameter so it never reads from + // ctx.UserValue — ctx may be recycled by the time this runs in a goroutine. + ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func(transportLogs []schemas.PluginLogEntry) { + if len(transportLogs) > 0 { tracer.AttachPluginLogs(traceID, transportLogs) } tracer.CompleteAndFlushTrace(traceID) diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 88109fd448..f0b299daaa 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -1504,7 +1504,7 @@ func (g *GenericRouter) handleAsyncCreate( } } - job, err := executor.SubmitJob(vkValue, resultTTL, operation, operationType) + job, err := executor.SubmitJob(vkValue, resultTTL, operation, operationType, bifrostCtx.GetUserValues()) if err != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(err, "failed to create async job")) @@ -2314,8 +2314,11 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *sc // The streaming callback will complete the trace after the stream ends ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) - // Get the trace completer function for use in the streaming callback - traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + // Get the trace completer function for use in the streaming callback. + // Signature is func([]schemas.PluginLogEntry) so the callback never reads from + // ctx.UserValue (ctx may be recycled by fasthttp by the time this fires). + // Router path has no transport post-hook phase, so we always pass nil. + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func([]schemas.PluginLogEntry)) // Get stream chunk interceptor for plugin hooks interceptor := g.handlerStore.GetStreamChunkInterceptor() @@ -2338,7 +2341,7 @@ func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *sc // Complete the trace after streaming finishes // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL if traceCompleter != nil { - traceCompleter() + traceCompleter(nil) } }() diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index 7f6127406d..84b08040fa 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -610,4 +610,17 @@ func BuildHTTPRequestFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPRequest // Note: Body not copied - for streaming, body was already consumed return req -} \ No newline at end of file +} + +// BuildHTTPResponseFromFastHTTP creates an HTTPResponse snapshot from fasthttp context. +// Only captures status code and headers — body is skipped because for streaming +// responses it is an active io.Reader that cannot be materialized. +// The returned response should be released with schemas.ReleaseHTTPResponse when done. +func BuildHTTPResponseFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.HTTPResponse { + resp := schemas.AcquireHTTPResponse() + resp.StatusCode = ctx.Response.StatusCode() + for key, value := range ctx.Response.Header.All() { + resp.Headers[string(key)] = string(value) + } + return resp +} diff --git a/transports/changelog.md b/transports/changelog.md index e69de29bb2..42a90326a1 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -0,0 +1,15 @@ +## ✨ Features + +- **OAuth MCP** - add next-step hints to OAuth MCP client creation response +- **Azure passthrough** - added azure passthrough support +- **272k token tier** - add 272k token tier pricing support in pricing +- **Flex and priority tier support** - added flex and priority tier support in pricing + +## 🐞 Fixed + +- **Streaming Post-Hook Race** — Fix race condition where fasthttp RequestCtx could be recycled before transport post-hooks complete in streaming goroutines; eagerly captures request/response snapshots before handler returns +- **Async User Values** — Propagate user values through all async inference handlers and job submissions +- **Trace Completer Safety** — Refactor trace completer to accept transport logs as parameter instead of reading from potentially recycled context +- **Async Log Store Exceptions** — Fix exception handling in async log store jobs +- **Model Alias Tracking** — Split ModelRequested into OriginalModelRequested and ResolvedModelUsed for accurate model alias resolution tracking +- **MCP Tool Discovery** — Add discovered tools and tool name mapping columns to MCP clients diff --git a/transports/version b/transports/version index 5e4f870eda..20ddf11d9e 100644 --- a/transports/version +++ b/transports/version @@ -1 +1 @@ -1.5.0-prerelease2 +1.5.0-prerelease3